Migrate/jit/matmul tiling 2d (#1472)
* refactor matmul files

* wip refactor matmul

* everything is memco

* support local arrays

* advancing tiling2d

* advancing tiling2d

* advancing tiling2d

* tiling2d finished but buggy

* configurable unrolling

* not bugged

* fails on unroll

* stupid break

* tiling2d no assumption works

* clippy

* bounds check as bool

* lhs rhs as enum

* tiling 2d major refactor

* remove assign vec4

* variable declarations above loops

* fmt

* clippy

* Fix autotune + unroll

* move val

* clippy

* fmt

---------

Co-authored-by: nathaniel <nathaniel.simard.42@gmail.com>
louisfd and nathanielsimard committed Mar 22, 2024
1 parent 0a8a3cc commit dd699a9
Showing 29 changed files with 1,296 additions and 330 deletions.
2 changes: 2 additions & 0 deletions crates/burn-jit/src/codegen/compiler.rs
@@ -22,4 +22,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
fn compile(shader: gpu::ComputeShader) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: gpu::Elem) -> usize;
/// The maximal size of shared memory
fn max_shared_memory_size() -> usize;
}
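
The new `max_shared_memory_size` method lets kernel configuration code validate a tiling setup against the backend's shared-memory budget before launch. A minimal sketch of such a check, assuming a runtime `R` and the crate-level `Compiler` and `Runtime` traits in scope (the helper name `fits_in_shared_memory` is hypothetical; it mirrors the assertion used by `Tiling2dConfig::new` further below):

// Hypothetical helper: true when a k-by-max(m, n) block fits in shared memory.
fn fits_in_shared_memory<R: Runtime>(
    block_size_k: usize,
    block_size_m: usize,
    block_size_n: usize,
) -> bool {
    block_size_k * block_size_m.max(block_size_n)
        <= <R::Compiler as Compiler>::max_shared_memory_size()
}
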
12 changes: 12 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/branch.rs
@@ -119,3 +119,15 @@ impl Loop {
parent_scope.register(Branch::Loop(op));
}
}

#[allow(missing_docs)]
pub struct UnrolledRangeLoop;

impl UnrolledRangeLoop {
/// Registers an unrolled range loop to the given scope.
pub fn register<F: Fn(Variable, &mut Scope)>(scope: &mut Scope, start: u32, end: u32, func: F) {
for i in start..end {
func(i.into(), scope);
}
}
}
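
For context, a hedged sketch of how this helper is used: unlike `RangeLoop`, which emits a single loop operation that runs on the device, the unrolled variant calls the body closure once per iteration at code-generation time, so the body is emitted repeatedly with the index fixed to a constant. `scope` is assumed to be a `&mut Scope` that is already being built:

// Sketch: unroll a loop over four tiles at code-generation time.
// The closure runs once per iteration, duplicating the emitted body with
// `i` fixed to the constants 0, 1, 2 and 3.
UnrolledRangeLoop::register(scope, 0, 4, |i, scope| {
    // ... emit the loop body into `scope` using the constant index `i` ...
});
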
21 changes: 20 additions & 1 deletion crates/burn-jit/src/codegen/dialect/gpu/macros.rs
@@ -293,6 +293,17 @@ macro_rules! gpu {
gpu!(unary $input, $out)
));
};
// out = vec4(a, b, c, d)
($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
let i = $scope.zero(Elem::UInt);
gpu!($scope, $out[i] = $a);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $b);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $c);
gpu!($scope, i = i + 1u32);
gpu!($scope, $out[i] = $d);
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
gpu!($scope, $out = cast($input))
@@ -326,10 +337,18 @@ macro_rules! gpu {
out: $out.into(),
});
};
// range(start, end).for_each(|scope| { ... })
// range(start, end).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
};
// range(start, end, unroll).for_each(|i, scope| { ... })
($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
if $unroll {
$crate::codegen::dialect::gpu::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg);
} else {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
}
};
// loop(|scope| { ... })
($scope:expr, loop($arg:expr)) => {
$crate::codegen::dialect::gpu::Loop::register($scope, $arg);
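
A short sketch of how the two new `gpu!` arms are meant to be invoked from kernel-building code; `scope`, `out`, `a` through `d`, and the compile-time `unroll` flag are assumed to already exist in the surrounding codegen function:

// Fill a vec4 output element by element (expands to four indexed assignments).
gpu!(scope, out = vec4(a, b, c, d));

// Unrolled at code-generation time when `unroll` is true, otherwise emitted
// as a device-side range loop.
gpu!(scope, range(0u32, 4u32, unroll).for_each(|i, scope| {
    // ... loop body ...
}));
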
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/operation.rs
@@ -36,6 +36,7 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),
26 changes: 21 additions & 5 deletions crates/burn-jit/src/codegen/dialect/gpu/scope.rs
@@ -20,7 +20,8 @@ pub struct Scope {
pub depth: u8,
pub operations: Vec<Operation>,
locals: Vec<Variable>,
shared: Vec<Variable>,
shared_memories: Vec<Variable>,
local_arrays: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
index_offset_with_output_layout_position: Vec<usize>,
writes_global: Vec<(Variable, Variable)>,
@@ -48,7 +49,8 @@ impl Scope {
depth: 0,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
local_arrays: Vec::new(),
shared_memories: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@@ -213,7 +215,8 @@ impl Scope {
depth: self.depth + 1,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
shared_memories: Vec::new(),
local_arrays: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@@ -308,7 +311,11 @@ impl Scope {
}

fn new_shared_index(&self) -> u16 {
self.shared.len() as u16
self.shared_memories.len() as u16
}

fn new_local_array_index(&self) -> u16 {
self.local_arrays.len() as u16
}

fn read_input_strategy(
@@ -339,7 +346,16 @@ impl Scope {
let item = item.into();
let index = self.new_shared_index();
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
self.shared.push(shared_memory);
self.shared_memories.push(shared_memory);
shared_memory
}

/// Create a local array of the given [item type](Item).
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
let item = item.into();
let index = self.new_local_array_index();
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
self.local_arrays.push(local_array);
local_array
}
}
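
A hedged sketch of how the two allocation helpers pair up in a tiled kernel: a shared-memory tile staged cooperatively by the workgroup, plus a small per-invocation array holding intermediate results. The sizes are illustrative, and `scope` and `item` (the element type, assumed to be `Copy`) are assumed to exist:

// Shared tile of 32 x 64 elements, visible to the whole workgroup.
let shared_lhs = scope.create_shared(item, 32 * 64);
// Per-invocation accumulator holding a 4 x 4 tile of results.
let results = scope.create_local_array(item, 16);
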
3 changes: 3 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/variable.rs
@@ -11,6 +11,7 @@ pub enum Variable {
LocalScalar(u16, Elem, u8),
ConstantScalar(f64, Elem),
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
Id,
LocalInvocationIndex,
LocalInvocationIdX,
@@ -41,6 +42,7 @@ impl Variable {
Variable::GlobalOutputArray(idx, _) => Some(*idx),
Variable::ConstantScalar(_, _) => None,
Variable::SharedMemory(idx, _, _) => Some(*idx),
Variable::LocalArray(idx, _, _, _) => Some(*idx),
Variable::Id => None,
Variable::LocalInvocationIndex => None,
Variable::LocalInvocationIdX => None,
@@ -70,6 +72,7 @@ impl Variable {
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
7 changes: 7 additions & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
@@ -52,6 +52,7 @@ impl Operator {
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
@@ -130,6 +131,12 @@ impl Variable {
item.vectorize(vectorize),
item.vectorized_size(vectorize, *size),
),
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
*index,
item.vectorize(vectorize),
*name,
item.vectorized_size(vectorize, *size),
),
Variable::ConstantScalar(_, _) => *self,
Variable::GlobalScalar(_, _) => *self,
Variable::Id => *self,
5 changes: 5 additions & 0 deletions crates/burn-jit/src/fusion/tracing/builder.rs
@@ -247,6 +247,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,
137 changes: 125 additions & 12 deletions crates/burn-jit/src/kernel/matmul/base.rs
@@ -1,24 +1,98 @@
use crate::{tensor::JitTensor, JitElement, Runtime};
use std::cmp::{max, min};

use burn_tensor::Shape;

use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime};

use super::{
init_matmul_output, matmul_autotune, matmul_mem_coalescing,
unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4,
init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded,
};

#[derive(Debug, Clone)]
/// Tiling 2D parameters
pub struct Tiling2dConfig {
/// Number of invocations in x
pub grid_x: usize,
/// Number of invocations in y
pub grid_y: usize,
/// Block size along the m dimension (rows of lhs)
pub block_size_m: usize,
/// Block size along the shared k dimension
pub block_size_k: usize,
/// Block size along the n dimension (columns of rhs)
pub block_size_n: usize,
/// Tile size along the m dimension
pub tile_size_m: usize,
/// Tile size along the n dimension
pub tile_size_n: usize,
}

impl Tiling2dConfig {
#[allow(unused)]
fn new<R: Runtime>(
grid_x: usize,
grid_y: usize,
block_size_m: usize,
block_size_k: usize,
block_size_n: usize,
tile_size_m: usize,
tile_size_n: usize,
) -> Self {
assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize);
assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize);
assert!(
block_size_k <= min(block_size_m, block_size_n),
"Not enough invocations to fill shared memory"
);
assert!(
block_size_k * max(block_size_m, block_size_n)
<= <R::Compiler as Compiler>::max_shared_memory_size(),
"Shared memory limit will be busted. "
);
assert!(
block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0,
"Tile size must divide block size in m and n dimensions"
);
Self {
grid_x,
grid_y,
block_size_m,
block_size_k,
block_size_n,
tile_size_m,
tile_size_n,
}
}
}

impl Default for Tiling2dConfig {
fn default() -> Self {
Self {
grid_x: 16,
grid_y: 16,
block_size_m: 64,
block_size_k: 32,
block_size_n: 64,
tile_size_m: 4,
tile_size_n: 4,
}
}
}

/// The strategy to be used when launching a matmul kernel.
#[derive(Default)]
pub enum MatmulStrategy {
/// A simple kernel will be used with memory coalescing optimization.
Simple {
/// Grad size x
/// Number of invocations in x
grid_x: usize,
/// Grad size y
/// Number of invocations in y
grid_y: usize,
},
/// A tiling 2d kernel will be used, with support for any matrix size without padding.
Tiling2d,
Tiling2d(Tiling2dConfig),
/// A tiling 2d kernel will be used, with support for any matrix size with padding.
Tiling2dPadded,
Tiling2dPadded(Tiling2dConfig),
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
#[default]
@@ -42,17 +116,56 @@ pub fn matmul<R: Runtime, E: JitElement, const D: usize>(
match strategy {
MatmulStrategy::Simple { grid_x, grid_y } => {
let out = init_matmul_output(&lhs, &rhs);
matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y)
matmul_simple(lhs, rhs, out, grid_x, grid_y)
}
MatmulStrategy::Tiling2d => {
MatmulStrategy::Tiling2d(config) => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_unpadded(lhs, rhs, out)
matmul_tiling_2d(lhs, rhs, out, config)
}
MatmulStrategy::Tiling2dPadded => {
MatmulStrategy::Tiling2dPadded(config) => {
let out = init_matmul_output(&lhs, &rhs);
matmul_tiling_2d_vec4(lhs, rhs, out)
matmul_tiling_2d_padded(lhs, rhs, out, config)
}
#[cfg(feature = "autotune")]
MatmulStrategy::Autotune => matmul_autotune(lhs, rhs),
}
}

pub(crate) fn simple_launch_options<const D: usize>(
lhs_shape: &Shape<D>,
rhs_shape: &Shape<D>,
output_shape: &Shape<D>,
workgroup_size_x: usize,
workgroup_size_y: usize,
) -> WorkGroup {
let num_rows = lhs_shape.dims[D - 2];
let num_cols = rhs_shape.dims[D - 1];

// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}

WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
}

pub(crate) fn tiling2d_launch_options<const D: usize>(
output_shape: &Shape<D>,
config: Tiling2dConfig,
) -> WorkGroup {
let num_rows = output_shape.dims[D - 2];
let num_cols = output_shape.dims[D - 1];

// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output_shape.dims[i];
}

WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
}
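
A usage sketch for the refactored entry point, assuming the argument order `(lhs, rhs, strategy)` and that both tensors already live on the JIT backend:

// Launch the unpadded tiling 2D kernel with the default configuration.
let out = matmul::<R, E, D>(lhs, rhs, MatmulStrategy::Tiling2d(Tiling2dConfig::default()));

As a worked example of the launch options: with the default config (block_size_m = block_size_n = 64) and an output of shape [2, 1000, 1000], tiling2d_launch_options dispatches ceil(1000 / 64) = 16 workgroups in x, 16 in y, and 2 in z (one per batch of the output).
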
15 changes: 12 additions & 3 deletions crates/burn-jit/src/kernel/matmul/mod.rs
@@ -1,13 +1,22 @@
mod base;
mod mem_coalescing;
mod simple;
mod tiling2d;
mod tiling2d_shader;
mod tune;

/// Contains utilities for the matmul operation
pub mod utils;

pub use base::*;
pub use mem_coalescing::*;
pub use tiling2d::*;
pub use simple::*;
pub use tune::*;
pub use utils::*;

#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
pub mod padding;

#[cfg(not(feature = "export_tests"))]
mod padding;

pub use tiling2d::*;