Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate/jit/matmul tiling 2d #1472

Merged
merged 28 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactor matmul files
  • Loading branch information
louisfd committed Mar 11, 2024
commit 83fd099d67970d720bcb31d5438d398162ae690d
105 changes: 98 additions & 7 deletions crates/burn-jit/src/kernel/matmul/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
use crate::{tensor::JitTensor, JitElement, Runtime};
use burn_compute::server::Handle;
use burn_tensor::Shape;

use crate::{
compute::{DynamicKernel, WorkGroup},
kernel::{build_info, into_contiguous, DynamicKernelSource},
ops::numeric::empty_device,
tensor::JitTensor,
JitElement, Runtime,
};

use super::{
init_matmul_output, matmul_autotune, matmul_mem_coalescing,
unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4,
init_matmul_output, matmul_autotune, matmul_simple,
padding::{crop, pad_round, PaddingOutput},
shape_out,
tiling2d::matmul_tiling_2d,
tiling2d_padded::matmul_tiling_2d_padded,
};

/// The strategy to be used when launching a matmul kernel.
Expand All @@ -25,7 +37,6 @@ pub enum MatmulStrategy {
Autotune,
}

#[cfg(feature = "autotune")]
#[cfg(not(feature = "autotune"))]
impl Default for MatmulStrategy {
fn default() -> Self {
Expand All @@ -42,17 +53,97 @@ pub fn matmul<R: Runtime, E: JitElement, const D: usize>(
match strategy {
MatmulStrategy::Simple { grid_x, grid_y } => {
let out = init_matmul_output(&lhs, &rhs);
matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y)
matmul_simple(lhs, rhs, out, grid_x, grid_y)
}
MatmulStrategy::Tiling2d => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_unpadded(lhs, rhs, out)
matmul_tiling_2d(lhs, rhs, out)
}
MatmulStrategy::Tiling2dPadded => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_vec4(lhs, rhs, out)
matmul_tiling_2d_padded(lhs, rhs, out)
}
#[cfg(feature = "autotune")]
MatmulStrategy::Autotune => matmul_autotune(lhs, rhs),
}
}

// Tiling-2d block sizes: each workgroup produces a B_M x B_N tile of the
// output (see `make_workgroup`, which divides the last two output dims by
// B_M / B_N). B_K is the shared-dimension step: lhs is padded to multiples
// of (B_M, B_K) and rhs to (B_K, B_N) in `matmul_tiling_2d_launch`.
pub(crate) const B_M: usize = 64;
pub(crate) const B_N: usize = 64;
pub(crate) const B_K: usize = 32;
// NOTE(review): presumably the edge length of the (square) workgroup used by
// the WGSL kernels — confirm against the blocktiling_2d templates.
pub(crate) const WORKGROUP_SIZE: usize = 16;

/// Computes the dispatch grid for a tiling-2d matmul launch.
///
/// The x/y dimensions cover the output's last two dims in blocks of
/// `B_M` x `B_N`; the z dimension is one block per batch (the product of
/// every leading dim).
pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
    // Integer ceiling division instead of `f32::ceil(a as f32 / b as f32)`:
    // the f32 round-trip can produce a wrong block count once a dimension
    // exceeds f32's 24-bit mantissa (> 16,777,216).
    let num_blocks_x = ((output_shape.dims[D - 2] + B_M - 1) / B_M) as u32;
    let num_blocks_y = ((output_shape.dims[D - 1] + B_N - 1) / B_N) as u32;
    // Product over an empty range is 1, matching the non-batched case.
    let num_blocks_z: usize = output_shape.dims[..D - 2].iter().product();

    WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
}

/// Builds the shape/stride info buffer for the three tensors involved in the
/// matmul and uploads it to the compute server, returning its handle.
pub(super) fn make_info_handle<R: Runtime, E: JitElement, const D: usize>(
    lhs: &JitTensor<R, E, D>,
    rhs: &JitTensor<R, E, D>,
    output: &JitTensor<R, E, D>,
) -> Handle<R::Server> {
    let info = build_info(&[lhs, rhs, output]);
    let bytes = bytemuck::cast_slice(&info);
    rhs.client.create(bytes)
}

/// Pads both inputs up to block-size multiples, dispatches `kernel` over a
/// rounded output buffer, then crops the result back into `output`.
#[allow(clippy::too_many_arguments)]
pub(super) fn matmul_tiling_2d_launch<
    R: Runtime,
    E: JitElement,
    const D: usize,
    K: DynamicKernelSource + 'static,
>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,
    output: JitTensor<R, E, D>,
    kernel: K,
) -> JitTensor<R, E, D> {
    // Padding implicitly makes a tensor contiguous. When no padding was
    // needed, contiguity is only forced if a batch dim was swapped with a
    // row/col dim; batch-among-batch swaps and a transposed last-two-dims
    // layout are handled directly by the underlying kernel.
    let resolve = |rounded| match rounded {
        PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
            into_contiguous(tensor)
        }
        other => other.into_tensor(),
    };

    let lhs = resolve(pad_round::<R, E, D>(lhs, B_M, B_K));
    let rhs = resolve(pad_round::<R, E, D>(rhs, B_K, B_N));

    let rounded_output_shape = shape_out(&lhs, &rhs);
    let rounded_output = empty_device(
        rhs.client.clone(),
        rhs.device.clone(),
        rounded_output_shape.clone(),
    );

    let workgroup = make_workgroup(&rounded_output_shape);
    let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);

    lhs.client.execute(
        Box::new(DynamicKernel::new(kernel, workgroup)),
        &[
            &lhs.handle,
            &rhs.handle,
            &rounded_output.handle,
            &info_handle,
        ],
    );

    // Slice the padded result back down to the caller-provided output tensor.
    crop(rounded_output, output)
}
18 changes: 14 additions & 4 deletions crates/burn-jit/src/kernel/matmul/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
mod base;
mod mem_coalescing;
mod tiling2d;
mod simple;
mod tune;

/// Contains utilities for the matmul operation
pub mod utils;

pub use base::*;
pub use mem_coalescing::*;
pub use tiling2d::*;
pub use simple::*;
pub use tune::*;
pub use utils::*;

#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
pub mod padding;

#[cfg(not(feature = "export_tests"))]
mod padding;

pub mod tiling2d;
pub mod tiling2d_padded;
pub use tiling2d::*;
pub use tiling2d_padded::*;
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ pub fn matmul_mem_coalescing_default<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
matmul_mem_coalescing::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
matmul_simple::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
}

/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZ

kernel_wgsl!(
MatmulTiling2DUnpaddedRaw,
"../../../template/matmul/blocktiling_2d/unpadded.wgsl"
"../../template/matmul/blocktiling_2d/unpadded.wgsl"
);

#[derive(new, Debug)]
Expand Down Expand Up @@ -45,7 +45,7 @@ impl<E: JitElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {

/// Matrix multiplication using tiling 2d algorithm with
/// vec4 primitive on both lhs and rhs, with no padding needed
pub fn matmul_tiling_2d_unpadded<R: Runtime, E: JitElement + Element, const D: usize>(
pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
Expand Down
91 changes: 0 additions & 91 deletions crates/burn-jit/src/kernel/matmul/tiling2d/base.rs

This file was deleted.

14 changes: 0 additions & 14 deletions crates/burn-jit/src/kernel/matmul/tiling2d/mod.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::marker::PhantomData;

kernel_wgsl!(
MatmulTiling2Dvec4Raw,
"../../../template/matmul/blocktiling_2d/vec4.wgsl"
"../../template/matmul/blocktiling_2d/vec4.wgsl"
);

#[derive(new, Debug)]
Expand Down Expand Up @@ -37,9 +37,7 @@ impl<E: JitElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
}
}

/// Matrix multiplication using tiling 2d algorithm with
/// vec4 primitive on both lhs and rhs
pub fn matmul_tiling_2d_vec4<R: Runtime, E: JitElement, const D: usize>(
pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement, const D: usize>(
lhs: JitTensor<R, E, D>,
rhs: JitTensor<R, E, D>,
out: JitTensor<R, E, D>,
Expand Down
47 changes: 15 additions & 32 deletions crates/burn-jit/src/kernel/matmul/tune/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,14 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J
);

vec![
Box::new(MemoryCoalescingMatmulDefault::new(
Box::new(SimpleMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
Box::new(SimpleMatmul16x16::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(MemoryCoalescingMatmulW16x16::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(Vec4TilingMatmulDefault::new(
lhs.clone(),
rhs.clone(),
out.clone(),
)),
Box::new(Vec4TilingMatmulUnpaddedDefault::new(
Box::new(Tiling2DMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
Box::new(Tiling2DMatmulPadded::new(
lhs.clone(),
rhs.clone(),
out.clone(),
Expand All @@ -75,16 +67,10 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J

fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
match fastest_index {
0 => Box::new(MemoryCoalescingMatmulDefault::new(
self.lhs, self.rhs, self.out,
)),
1 => Box::new(MemoryCoalescingMatmulW16x16::new(
self.lhs, self.rhs, self.out,
)),
2 => Box::new(Vec4TilingMatmulDefault::new(self.lhs, self.rhs, self.out)),
3 => Box::new(Vec4TilingMatmulUnpaddedDefault::new(
self.lhs, self.rhs, self.out,
)),
0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)),
1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)),
2 => Box::new(Tiling2DMatmul::new(self.lhs, self.rhs, self.out)),
3 => Box::new(Tiling2DMatmulPadded::new(self.lhs, self.rhs, self.out)),
_ => panic!("Fastest index is out of bound"),
}
}
Expand Down Expand Up @@ -134,23 +120,20 @@ macro_rules! matmul_tune_ops {

// Potentially better for small matrices.
matmul_tune_ops!(
MemoryCoalescingMatmulDefault,
SimpleMatmul,
crate::kernel::matmul::matmul_mem_coalescing_default
);

// Potentially better for small matrices.
matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| {
crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16)
});

// Probably the fastest when fixed sizes.
matmul_tune_ops!(
Vec4TilingMatmulDefault,
crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
Tiling2DMatmulPadded,
crate::kernel::matmul::matmul_tiling_2d_padded
);

// Probably the fastest otherwise.
matmul_tune_ops!(
Vec4TilingMatmulUnpaddedDefault,
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
);
// Probably the fastest in the general case
matmul_tune_ops!(Tiling2DMatmul, crate::kernel::matmul::matmul_tiling_2d);