Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate/jit/matmul tiling 2d #1472

Merged
merged 28 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tiling2d no assumption works
  • Loading branch information
louisfd committed Mar 14, 2024
commit a0bb940af6a970ff70db00528e834767bd5a4c6c
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum Operator {
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Ceil(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl Operator {
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)),
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
Expand Down
5 changes: 5 additions & 0 deletions crates/burn-jit/src/fusion/tracing/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Ceil(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,
Expand Down
11 changes: 9 additions & 2 deletions crates/burn-jit/src/kernel/matmul/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@ use burn_tensor::Shape;
use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime};

use super::{
init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d,
tiling2d_padded::matmul_tiling_2d_padded,
init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded,
};

/// Assumption the tiling 2D kernel is allowed to make about its input shapes.
///
/// The `Debug` representation of this enum is embedded in the kernel id, so
/// the variant names are part of the kernel cache key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Tiling2DAssumption {
    /// Input shapes are divisible by their corresponding block sizes.
    Round,
    /// No assumption can be made: bounds must be checked.
    None,
}

#[derive(Debug, Clone)]
/// Tiling 2D parameters
pub struct Tiling2dConfig {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/matmul/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod base;
mod simple;
mod tiling2d;
mod tiling2d_padded;
mod tiling2d_shader;
mod tune;

/// Contains utilities for the matmul operation
Expand All @@ -20,4 +20,4 @@ pub mod padding;
mod padding;

pub use tiling2d::*;
pub use tiling2d_padded::*;
use tiling2d_shader::*;
219 changes: 86 additions & 133 deletions crates/burn-jit/src/kernel/matmul/tiling2d.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use burn_tensor::Element;
use burn_tensor::{Element, Shape};

use crate::{
codegen::{
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
Execution, InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
gpu::{gpu, BinaryOperator, Branch, Elem, IndexOffsetGlobalWithLayout, Scope, Variable},
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate},
tensor::JitTensor,
Runtime,
};
use std::marker::PhantomData;

use super::{tiling2d_launch_options, Tiling2dConfig};
use super::{
padding::{crop, pad_round, PaddingOutput},
shape_out, tiling2d_launch_options, MatmulTiling2dShader, Tiling2DAssumption, Tiling2dConfig,
};

#[derive(new, Debug)]
struct MatmulTiling2d<E: JitElement> {
Expand All @@ -23,136 +25,10 @@ struct MatmulTiling2d<E: JitElement> {
/// Eager kernel source for the tiling 2D matmul, parameterized by the runtime `R`.
///
/// Both `config` and `assumption` are forwarded to the shader and are part of
/// the kernel id string.
#[derive(new, Debug)]
struct MatmulTiling2dEagerKernel<R: Runtime> {
/// Tiling 2D parameters handed to the shader.
config: Tiling2dConfig,
/// Shape assumption the generated shader is built with.
assumption: Tiling2DAssumption,
/// Zero-sized marker that ties the kernel to the runtime type `R`.
_runtime: PhantomData<R>,
}

/// Inputs needed to expand the matmul shader into a GPU scope.
struct MatmulTiling2dShader {
/// The `lhs`, `rhs` and `out` tensor variables of the matmul.
variables: BinaryOperator,
/// Side length used to turn the local invocation index and workgroup ids
/// into a (row, col) position.
block_size: usize,
}

impl MatmulTiling2dShader {
/// Registers the matmul kernel body on `scope`.
///
/// As written, every invocation handles exactly one output element: it
/// derives a `(row, col)` position from the workgroup ids and the local
/// invocation index, returns early when that position is outside the output
/// matrix, accumulates the products over the shared `k` dimension and writes
/// the sum to `out`.
fn expand(self, scope: &mut Scope) {
// Define our global variables.
let local_idx = Variable::LocalInvocationIndex;
let batch = Variable::GlobalInvocationIdZ;
let rank = Variable::Rank;
let block_size: Variable = self.block_size.into();

// Extract tensor variables.
let lhs = self.variables.lhs;
let rhs = self.variables.rhs;
let out = self.variables.out;

// Define where we have to work on the current matrix.
// `tmp_index` is a scratch variable that is reused several times below.
let tmp_index = scope.create_local(Elem::UInt);
let batch_dims = scope.create_local(Elem::UInt);
let row = scope.create_local(Elem::UInt);
let col = scope.create_local(Elem::UInt);

// Row position.
// row = block_size * WorkgroupIdX + local_idx / block_size
gpu!(scope, tmp_index = local_idx / block_size);
gpu!(scope, row = block_size * Variable::WorkgroupIdX);
gpu!(scope, row = row + tmp_index);

// Col position.
// col = block_size * WorkgroupIdY + local_idx % block_size
gpu!(scope, tmp_index = local_idx % block_size);
gpu!(scope, col = block_size * Variable::WorkgroupIdY);
gpu!(scope, col = col + tmp_index);

// Batch position.
// `batch_dims` counts the leading dimensions before the last two; it is
// also used below as the index of the row dimension.
gpu!(scope, batch_dims = rank - 2u32);

// Define the matrix size.
let n_rows = scope.create_local(Elem::UInt);
let n_cols = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);

// Number of rows.
gpu!(scope, n_rows = shape(out, batch_dims));

// Number of cols.
gpu!(scope, tmp_index = batch_dims + 1u32);
gpu!(scope, n_cols = shape(out, tmp_index));

// The dimension that is going to be squashed.
// `tmp_index` still holds `batch_dims + 1`, i.e. the last dimension of `lhs`.
gpu!(scope, k = shape(lhs, tmp_index));

// Check if there is some work to be done.
// Invocations whose row or column is outside the output matrix return early.
let should_stop = scope.create_local(Elem::Bool);
gpu!(scope, should_stop = row >= n_rows);
gpu!(scope, if (should_stop).then(|scope| {
scope.register(Branch::Return);
}));

gpu!(scope, should_stop = col >= n_cols);
gpu!(scope, if (should_stop).then(|scope| {
scope.register(Branch::Return);
}));

// Calculate the batch offset.
let offset_lhs = scope.zero(Elem::UInt);
let offset_rhs = scope.zero(Elem::UInt);
let offset_output = scope.create_local(Elem::UInt);

// Batch offset for the output.
// offset_output = n_rows * n_cols * batch
gpu!(scope, offset_output = n_rows * n_cols);
gpu!(scope, offset_output = offset_output * batch);

// Batch offset for the lhs & rhs matrices.
// NOTE(review): presumably this fills `offset_lhs` / `offset_rhs` from
// `offset_output` over dimensions [0, batch_dims) using the layout of
// `out` — confirm against `IndexOffsetGlobalWithLayout`.
IndexOffsetGlobalWithLayout {
tensors: vec![lhs, rhs],
indexes: vec![offset_lhs, offset_rhs],
layout: out,
index_ref: offset_output,
dim_start: 0u32.into(),
dim_end: batch_dims,
}
.expand(scope);

// Calculate the dot product (row X col).
let sum = scope.create_local(out.item());

// Initialize the sum to zero.
// NOTE(review): the zero constant is built from an `f32` literal while `sum`
// has the item type of `out` — confirm this is fine for non-float outputs.
let zero: Variable = 0f32.into();
gpu!(scope, sum = zero);

// Loop over the k dimension.
gpu!(
scope,
range(0u32, k).for_each(|i, scope| {
let lhs_index = scope.create_local(Elem::UInt);
let rhs_index = scope.create_local(Elem::UInt);

let lhs_value = scope.create_local(lhs.item());
let rhs_value = scope.create_local(rhs.item());
let out_value = scope.create_local(out.item());

// lhs_index = row * k + i + offset_lhs
// NOTE(review): this and the rhs formula below look like they assume a
// row-major, contiguous last-two-dims layout — confirm with callers.
gpu!(scope, lhs_index = row * k);
gpu!(scope, lhs_index = lhs_index + i);
gpu!(scope, lhs_index = lhs_index + offset_lhs);

// rhs_index = i * n_cols + col + offset_rhs
gpu!(scope, rhs_index = i * n_cols);
gpu!(scope, rhs_index = rhs_index + col);
gpu!(scope, rhs_index = rhs_index + offset_rhs);

gpu!(scope, lhs_value = lhs[lhs_index]);
gpu!(scope, rhs_value = rhs[rhs_index]);

// Accumulate the product of the two loaded values.
gpu!(scope, out_value = lhs_value * rhs_value);
gpu!(scope, sum += out_value);
})
);

// Write the accumulated sum at out_index = row * n_cols + col + offset_output.
let out_index = scope.create_local(Elem::UInt);

gpu!(scope, out_index = row * n_cols);
gpu!(scope, out_index += col);
gpu!(scope, out_index += offset_output);
gpu!(scope, out[out_index] = sum);
}
}

impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {
fn source(&self) -> SourceTemplate {
let mut scope = gpu::Scope::root();
Expand All @@ -164,7 +40,9 @@ impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {

MatmulTiling2dShader {
variables: gpu::BinaryOperator { lhs, rhs, out },
block_size: self.config.grid_x, // TODO
config: self.config.clone(),
assumption: self.assumption.clone(),
unroll: false,
}
.expand(&mut scope);

Expand Down Expand Up @@ -198,9 +76,10 @@ impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {

fn id(&self) -> String {
format!(
"{:?}config={:?}",
"{:?}config={:?}assumption={:?}",
core::any::TypeId::of::<Self>(),
self.config,
self.assumption
)
}
}
Expand All @@ -213,7 +92,9 @@ pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
out: JitTensor<R, E, D>,
config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone());
let assumption = check_assumption(&lhs.shape, &rhs.shape, &config);

let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), assumption);
let client = lhs.client.clone();

let lhs = match lhs.batch_swapped_with_row_col() {
Expand All @@ -237,3 +118,75 @@ pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(

out
}

/// Matrix multiplication using the tiling 2d algorithm, padding the inputs when needed.
///
/// Both operands are rounded up to multiples of the configured block sizes, the
/// kernel is launched with the [`Tiling2DAssumption::Round`] assumption on the
/// rounded shapes, and the rounded result is finally cropped into `out`.
pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement + Element, const D: usize>(
    lhs: JitTensor<R, E, D>,
    rhs: JitTensor<R, E, D>,
    out: JitTensor<R, E, D>,
    config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
    let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), Tiling2DAssumption::Round);
    let client = lhs.client.clone();

    // Padding a tensor implicitly makes it contiguous. An operand that did not need
    // padding is only made contiguous when a batch dim was swapped with the row or
    // col dim; batch-only swaps and transposed last two dims are handled by the
    // kernel itself.
    let lhs = match pad_round::<R, E, D>(lhs, config.block_size_m, config.block_size_k) {
        PaddingOutput::Unchanged(unpadded) if unpadded.batch_swapped_with_row_col() => {
            into_contiguous(unpadded)
        }
        padding => padding.into_tensor(),
    };
    let rhs = match pad_round::<R, E, D>(rhs, config.block_size_k, config.block_size_n) {
        PaddingOutput::Unchanged(unpadded) if unpadded.batch_swapped_with_row_col() => {
            into_contiguous(unpadded)
        }
        padding => padding.into_tensor(),
    };

    // Allocate an output large enough for the rounded shapes.
    let padded_shape = shape_out(&lhs, &rhs);
    let byte_len = padded_shape.num_elements() * core::mem::size_of::<E>();
    let buffer = client.empty(byte_len);
    let padded_out = JitTensor::new(
        rhs.client.clone(),
        rhs.device.clone(),
        padded_shape.clone(),
        buffer,
    );

    let lhs_handle = EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims);
    let rhs_handle = EagerHandle::<R>::new(&rhs.handle, &rhs.strides, &rhs.shape.dims);
    let out_handle = EagerHandle::<R>::new(
        &padded_out.handle,
        &padded_out.strides,
        &padded_out.shape.dims,
    );

    Execution::start(kernel, client)
        .inputs(&[lhs_handle, rhs_handle])
        .outputs(&[out_handle])
        .execute(WorkgroupLaunch::Custom(tiling2d_launch_options(
            &padded_out.shape,
            config,
        )));

    // Drop the padding again so the result has the shape of `out`.
    crop(padded_out, out)
}

/// Picks the assumption the kernel may rely on for the given operand shapes.
///
/// Returns [`Tiling2DAssumption::Round`] when the `m`, `k` and `n` dimensions
/// (second-to-last and last dims of `lhs`, last dim of `rhs`) are each a
/// multiple of their block size, and [`Tiling2DAssumption::None`] otherwise.
fn check_assumption<const D: usize>(
    lhs_shape: &Shape<D>,
    rhs_shape: &Shape<D>,
    config: &Tiling2dConfig,
) -> Tiling2DAssumption {
    // All three remainders are computed up front, as separate statements.
    let m_remainder = lhs_shape.dims[D - 2] % config.block_size_m;
    let k_remainder = lhs_shape.dims[D - 1] % config.block_size_k;
    let n_remainder = rhs_shape.dims[D - 1] % config.block_size_n;

    let all_round = [m_remainder, k_remainder, n_remainder]
        .iter()
        .all(|remainder| *remainder == 0);

    if all_round {
        Tiling2DAssumption::Round
    } else {
        Tiling2DAssumption::None
    }
}
Loading
Loading