[ORTModule] Symbolic Shape Support for Triton Codegen (#18317)
Add symbolic shape support for Triton codegen for ORTModule.
er3x3 committed Nov 13, 2023
1 parent 73ed34a commit 4a82030
Showing 10 changed files with 324 additions and 120 deletions.
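
For orientation, the change makes the generated Triton driver resolve dynamic dimensions from the actual input tensors at call time and forward them to the kernel, instead of baking concrete sizes into the generated source. The sketch below is illustrative only (the kernel, tensor shapes, and block size are made up for this note and are not taken from the diff), but it shows the runtime pattern the generated code follows:

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    # Standard elementwise pattern: mask the tail when xnumel is not a multiple of XBLOCK.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x = tl.load(in_ptr + xindex, mask=xmask)
    tl.store(out_ptr + xindex, x + 1.0, mask=xmask)


def call(x):
    # "i0" plays the role of a symbolic batch dimension: it is read from the input
    # at run time rather than fixed when the source was generated.
    i0 = x.size()[0]
    n_elements = i0 * x.size()[1]
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)
    add_one_kernel[grid](x, out, n_elements, XBLOCK=1024)
    return out


# call(torch.randn(7, 1024, device="cuda")) works for any leading dim without regenerating code.
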
82 changes: 55 additions & 27 deletions orttraining/orttraining/python/training/ort_triton/_codegen.py
@@ -45,12 +45,9 @@ class TritonCodegen(NodeVisitor):
Specialized codegen for Triton backend.
"""

def __init__(self):
super().__init__()

def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
func = getattr(self, node.__class__.__name__)
assert func is not None, "unimplemented node: %s" % node.__class__.__name__
assert func is not None, f"unimplemented node: {node.__class__.__name__}"
func(node, context, code_buffer, indent)

def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]:
@@ -125,18 +122,29 @@ def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer,
def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int):
is_reduction = node.offset_calc.is_reduction
space_indent = " " * indent
autotune_configs_str = ""
for config in node.offset_calc.autotune_configs.configs:
if is_reduction:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
f"num_warps={config[2]}),\n"
)
else:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
)
keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'

if len(node.offset_calc.autotune_configs.configs) > 1:
autotune_configs_str = ""
for config in node.offset_calc.autotune_configs.configs:
if is_reduction:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, '
f"num_warps={config[2]}),\n"
)
else:
autotune_configs_str += (
f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n'
)
keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"'
code_buffer += (
f"{space_indent}@triton.autotune(\n"
f"{space_indent} configs=[\n"
f"{autotune_configs_str}"
f"{space_indent} ],\n"
f"{space_indent} key=[{keys_str}],\n"
f"{space_indent})\n"
)

input_args = [context.get_variable_name(input.name) for input in node.inputs]
input_args_str = ", ".join(input_args)
if input_args_str:
@@ -158,12 +166,6 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_
)

code_buffer += (
f"{space_indent}@triton.autotune(\n"
f"{space_indent} configs=[\n"
f"{autotune_configs_str}"
f"{space_indent} ],\n"
f"{space_indent} key=[{keys_str}],\n"
f"{space_indent})\n"
f"{space_indent}@triton.jit\n"
f"{space_indent}def {node.name}({input_args_str}{output_args_str}{other_input_args}{blocks_str}):\n"
)
@@ -175,8 +177,10 @@ def ElementwiseKernelNode(  # noqa: N802
offset_calc = node.offset_calc
indent += 4
space_indent = " " * indent
x_numel_str = str(offset_calc.x_numel)
if x_numel_str.isnumeric():
code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
code_buffer += (
f"{space_indent}xnumel = {offset_calc.x_numel}\n"
f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)\n"
)
@@ -207,9 +211,13 @@ def ReduceKernelNode(  # noqa: N802
offset_calc = node.offset_calc
indent += 4
space_indent = " " * indent
x_numel_str = str(offset_calc.x_numel)
if x_numel_str.isnumeric():
code_buffer += f"{space_indent}xnumel = {x_numel_str}\n"
r_numel_str = str(offset_calc.r_numel)
if r_numel_str.isnumeric():
code_buffer += f"{space_indent}rnumel = {r_numel_str}\n"
code_buffer += (
f"{space_indent}xnumel = {offset_calc.x_numel}\n"
f"{space_indent}rnumel = {offset_calc.r_numel}\n"
f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n"
f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n"
f"{space_indent}rbase = tl.arange(0, RBLOCK)[None, :]\n"
@@ -444,6 +452,13 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
indent += 4
space_indent = " " * indent

seen_symbolic_shape = set()
for input in node.inputs:
for idx, dim in enumerate(input.shape):
if dim.is_symbol and dim not in seen_symbolic_shape:
code_buffer += f"{space_indent}{dim} = {context.get_variable_name(input.name)}.size()[{idx}]\n"
seen_symbolic_shape.add(dim)

if node.has_dropout:
code_buffer += (
f'{space_indent}seed_cuda = torch.randint(2**31, size=(), dtype=torch.int64, device="cuda")\n\n'
@@ -470,18 +485,31 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod
if kernel_node.has_dropout:
kernel_args_str += ", seed_cuda"

# Support symbolic shape if any.
symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables)
if symbolic_shape_args_str:
kernel_args_str += f", {symbolic_shape_args_str}"

block_str = ""
if len(kernel_node.offset_calc.autotune_configs.configs) == 1:
config = kernel_node.offset_calc.autotune_configs.configs[0]
if kernel_node.offset_calc.is_reduction:
block_str = f", XBLOCK={config[0]}, RBLOCK={config[1]}, num_warps={config[2]}"
else:
block_str = f", XBLOCK={config[0]}, num_warps={config[2]}"

if isinstance(kernel_node, ReduceKernelNode):
code_buffer += (
f"{space_indent}x_numel = {kernel_node.offset_calc.x_numel}\n"
f"{space_indent}r_numel = {kernel_node.offset_calc.r_numel}\n"
f'{space_indent}grid = lambda meta: (triton.cdiv(x_numel, meta["XBLOCK"]),)\n'
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel)\n"
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel{block_str})\n"
)
else:
code_buffer += (
f"{space_indent}n_elements = {kernel_node.offset_calc.x_numel}\n"
f'{space_indent}grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)\n'
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements)\n"
f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements{block_str})\n"
)

for name in node.cross_kernel_args_to_delete[idx]:
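
To make the new ModuleNode logic concrete: each distinct symbolic dimension is bound exactly once, from the size of the first input tensor that carries it, before any kernel is launched. Below is a minimal standalone restatement of that loop; the (name, shape) input structure and the helper name are invented for illustration, while the real code walks the module's IR input nodes:

import sympy


def emit_symbol_bindings(inputs):
    # inputs: list of (variable_name, shape) pairs, where shape mixes ints and sympy symbols.
    lines, seen = [], set()
    for var_name, shape in inputs:
        for idx, dim in enumerate(shape):
            if isinstance(dim, sympy.Symbol) and dim not in seen:
                lines.append(f"{dim} = {var_name}.size()[{idx}]")
                seen.add(dim)
    return lines


i0, i1 = sympy.symbols("i0 i1")
print(emit_symbol_bindings([("x", [i0, 1024]), ("y", [i0, i1])]))
# ['i0 = x.size()[0]', 'i1 = y.size()[1]'] -- i0 is emitted only once.
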
116 changes: 95 additions & 21 deletions orttraining/orttraining/python/training/ort_triton/_common.py
@@ -9,9 +9,11 @@
import sympy
from onnx import GraphProto, NodeProto, TensorProto

from ._sympy_utils import parse_shape
from ._sympy_utils import extract_shape_from_symbol
from ._utils import get_attribute, get_reduce_info, next_power_of_2

_SPECIAL_FLOATS: List[str] = ["inf", "-inf"]


class CodegenContext:
"""
@@ -28,7 +30,8 @@ def get_variable_name(self, name: str) -> str:
# For some operators such as data load/store, we need an internal variable name inside the kernel function.
def get_internal_variable_name(self, name: str) -> str:
var_name = self._var_map[name]
return self._var_map[var_name] if var_name in self._var_map else var_name
var_name = self._var_map[var_name] if var_name in self._var_map else var_name
return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name


class CodeBuffer:
@@ -49,14 +52,38 @@ def codegen(self, node: Any, context: CodegenContext, code_buffer: CodeBuffer, i
pass


class SymbolicDSU:
"""
A 'disjoint set union' to merge symbolics so that we use fewer variables in the generated code.
During shape inference for elementwise ops, if two symbolic dims are not equal and neither is 1, we merge them.
"""

def __init__(self):
self._dsu: Dict[sympy.Expr, sympy.Expr] = {}

def find(self, symbolic: sympy.Expr) -> sympy.Expr:
if symbolic not in self._dsu:
self._dsu[symbolic] = symbolic
return symbolic
if symbolic == self._dsu[symbolic]:
return symbolic
self._dsu[symbolic] = self.find(self._dsu[symbolic])
return self._dsu[symbolic]

def union(self, symbolic: sympy.Expr, other_symbolic: sympy.Expr):
root = self.find(symbolic)
other_root = self.find(other_symbolic)
self._dsu[other_root] = root
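
Illustrative usage of the class above (not part of the diff): when shape inference discovers that two differently named symbols must denote the same runtime size, union-ing them leaves a single representative for codegen to bind.

import sympy

i0, i1, i2 = sympy.symbols("i0 i1 i2")
dsu = SymbolicDSU()  # assumes the SymbolicDSU defined above is in scope
dsu.union(i0, i1)    # e.g. two inputs disagree on the name of the same batch dim
dsu.union(i0, i2)
assert dsu.find(i1) == dsu.find(i2) == i0  # one representative symbol survives

find compresses paths as it resolves, so repeated lookups stay cheap.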


class TensorInfo:
"""
Represent an input/output tensor of a node.
"""

def __init__(self, dtype: TensorProto.DataType, shape: List[Any]):
def __init__(self, dtype: TensorProto.DataType, shape: List[sympy.Expr]):
self._dtype: TensorProto.DataType = dtype
self._shape: List[sympy.Expr] = parse_shape(shape)
self._shape: List[sympy.Expr] = shape

@property
def dtype(self) -> TensorProto.DataType:
@@ -66,27 +93,42 @@ def dtype(self) -> TensorProto.DataType:
def shape(self) -> List[sympy.Expr]:
return self._shape

def update_shape(self, symbolics: SymbolicDSU):
self._shape = [symbolics.find(dim) if dim.is_symbol else dim for dim in self._shape]


def _infer_elementwise_shape(input_infos: List[TensorInfo]) -> List[sympy.Expr]:
def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicDSU) -> List[sympy.Expr]:
max_len = max([len(input_info.shape) for input_info in input_infos])
output_shape: List[sympy.Expr] = [sympy.Integer(1)] * max_len
for input_info in input_infos:
offset = max_len - len(input_info.shape)
for i in range(len(input_info.shape)):
if not input_info.shape[i].is_number or input_info.shape[i] != 1:
output_shape[i + offset] = input_info.shape[i]
for idx, dim in enumerate(input_info.shape):
if not dim.is_number or dim != 1:
if not output_shape[idx + offset].is_number or output_shape[idx + offset] != 1:
symbolics.union(output_shape[idx + offset], dim)
else:
output_shape[idx + offset] = dim
return output_shape
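
As a worked example of the broadcasting rule above (illustrative, assuming the classes in this file are in scope): two inputs whose leading dims are different symbols broadcast to a single symbolic dim, and the DSU records that the two symbols must agree at run time.

import sympy
from onnx import TensorProto

M, N = sympy.symbols("M N")
symbolics = SymbolicDSU()
a = TensorInfo(TensorProto.FLOAT, [M, sympy.Integer(1)])
b = TensorInfo(TensorProto.FLOAT, [N, sympy.Integer(128)])
out_shape = _infer_elementwise_shape([a, b], symbolics)  # -> [M, 128]
assert symbolics.find(N) == symbolics.find(M)  # M and N were merged
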


def _infer_elementwise(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos))]
def _infer_elementwise(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos, symbolics))]


def _infer_where(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos))]
def _infer_where(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos, symbolics))]


def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_reduction(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
input_rank = len(input_infos[0].shape)
keep_dims, axes = get_reduce_info(node, graph, input_rank)
axes = [axis + input_rank if axis < 0 else axis for axis in axes]
@@ -98,17 +140,26 @@ def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: Grap
return [TensorInfo(input_infos[0].dtype, shape)]


def _infer_unary(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_unary(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [input_infos[0]]


def _infer_cast(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_cast(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
dtype = get_attribute(node, "to", TensorProto.UNDEFINED)
assert dtype != TensorProto.UNDEFINED
return [TensorInfo(dtype, input_infos[0].shape)]


def _infer_dropout(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def _infer_dropout(
node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
# pylint: disable=unused-argument
return [input_infos[0], TensorInfo(TensorProto.BOOL, input_infos[0].shape)]


@@ -138,10 +189,12 @@ class TypeAndShapeInfer:
}

@classmethod
def infer(cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]:
def infer(
cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU
) -> List[TensorInfo]:
if node.op_type not in cls._INFER_FUNC_MAP:
raise NotImplementedError(f"Unsupported op type: {node.op_type}")
return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph)
return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph, symbolics)


class AutotuneConfigs:
@@ -152,9 +205,30 @@ class AutotuneConfigs:
If it's a reduction kernel over the last contiguous dimensions, the contiguous flag is True.
"""

def __init__(self, x_numel: int, r_numel: int, contiguous: bool):
self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel, r_numel, contiguous)
self.requires_for_loop: bool = any(config[1] < r_numel for config in self.configs)
def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool):
x_numel_int = (
int(x_numel)
if x_numel.is_number
else int(
x_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols}
)
)
)
r_numel_int = (
int(r_numel)
if r_numel.is_number
else int(
r_numel.subs(
{symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols}
)
)
)
self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous)
# If there is symbolic shape, we will not tune the kernel.
if not x_numel.is_number or not r_numel.is_number:
self.configs = self.configs[-1:]
self.requires_for_loop: bool = any(config[1] < r_numel_int for config in self.configs)

def _num_warps(self, x: int, r: int) -> int:
return min(max(x * r // 256, 2), 8)
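
AutotuneConfigs needs integer numels to generate candidate block sizes, so symbolic expressions are first concretized with stand-in values, and only a single config is kept (no autotuning) when any numel is symbolic. The substitution pattern looks roughly like this; the extractor below is a made-up stand-in for extract_shape_from_symbol, whose real heuristic lives in _sympy_utils:

import sympy


def fake_extract_shape_from_symbol(name: str) -> int:
    # Stand-in heuristic only: pretend every unknown dimension is 4096.
    return 4096


i0 = sympy.Symbol("i0")
x_numel = i0 * 1024  # symbolic numel expression
x_numel_int = int(
    x_numel.subs({s: sympy.Integer(fake_extract_shape_from_symbol(s.name)) for s in x_numel.free_symbols})
)
print(x_numel_int)  # 4194304 -- used only to size candidate configs, never baked into the kernel
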
15 changes: 11 additions & 4 deletions orttraining/orttraining/python/training/ort_triton/_decompose.py
@@ -58,7 +58,7 @@ def _get_dtype_and_shape(self, arg_name: str, **kwargs):
arg_info = node_arg_infos[arg_name]
return arg_info.dtype, arg_info.shape

def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, **kwargs):
def _decompose_elementwise_precision(self, node: NodeProto, **kwargs):
x = node.input[0]
dtype, _ = self._get_dtype_and_shape(x, **kwargs)
if not _is_half_dtype(dtype):
@@ -79,15 +79,19 @@ def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, *
return [*cast_nodes, op_node, cast_node1]

def Exp(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def Pow(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def Sqrt(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
return self._decompose_elementwise_precision(node, graph, **kwargs)
# pylint: disable=unused-argument
return self._decompose_elementwise_precision(node, **kwargs)

def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
x = node.input[0]
w = node.input[1]
@@ -153,6 +157,7 @@ def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs):  # n
]

def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
dy = node.input[0]
x = node.input[1]
@@ -241,6 +246,7 @@ def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs):
return decomposed_nodes

def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
x = node.input[0]
y = node.output[0]
@@ -259,6 +265,7 @@ def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs):  # noqa: N802
return [max_node, sub_node, exp_node, sum_node, div_node]

def SoftmaxGrad_13(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802
# pylint: disable=unused-argument
node_name = node.name
dy = node.input[0]
y = node.input[1]