Add fusions for OpenAI CLIP (#20721)
### Description
This PR adds fusions for [OpenAI's CLIP
model](https://huggingface.co/openai/clip-vit-large-patch14-336). Here
is an example of how to run the ORT transformer optimizer for the linked
CLIP model.

```bash
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0
```
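
The same optimization can also be invoked from Python. Below is a minimal sketch using the `optimize_model` API in `optimizer.py`; treat the exact argument names as assumptions if the API has changed since this commit.

```python
# Minimal sketch: run the CLIP fusions from Python instead of the CLI.
# Assumes the working directory is onnxruntime/python/tools/transformers
# so that `optimizer` is importable; the model paths are placeholders.
from optimizer import optimize_model

opt_model = optimize_model(
    "/path/to/model.onnx",
    model_type="clip",  # selects the CLIP fusion pass added in this PR
    num_heads=16,
    hidden_size=1024,
    opt_level=0,
)
opt_model.save_model_to_file("/path/to/model_opt.onnx", use_external_data_format=True)
```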

### Motivation and Context
This PR helps optimize multi-modal models that use CLIP for the vision
encoder.
kunal-vaishnavi authored May 18, 2024
1 parent 5d07291 commit ca22a5a
Showing 6 changed files with 167 additions and 31 deletions.
81 changes: 60 additions & 21 deletions onnxruntime/python/tools/transformers/fusion_attention_clip.py
```diff
@@ -97,7 +97,33 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         else:
             # Deal with the first attention after the embedding layer.
             for i in [0, 1]:
-                node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
+                node_before_layer_norm = None
+
+                node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
+                node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
+                if node_before_layer_norm_1 is not None:
+                    # Add ----------------+
+                    # |                   |
+                    # LayerNorm           |
+                    # |                   |
+                    # LayerNorm           |
+                    # |                   |
+                    # Attention subgraph  |
+                    # |                   |
+                    # SkipLayerNorm ------+
+                    node_before_layer_norm = node_before_layer_norm_1
+                elif node_before_layer_norm_2 is not None:
+                    # Add
+                    # |
+                    # LayerNorm ----------+
+                    # |                   |
+                    # LayerNorm           |
+                    # |                   |
+                    # Attention subgraph  |
+                    # |                   |
+                    # SkipLayerNorm ------+
+                    node_before_layer_norm = node_before_layer_norm_2
+
                 if node_before_layer_norm is None:
                     continue
                 child = self.model.find_first_child_by_type(
@@ -130,20 +156,32 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             return
         (_, _, reshape_v, add_v, matmul_v) = v_nodes
 
+        add_mask = None
         add_mask_indices = []
-        qk_nodes = self.model.match_parent_path(
+        qk_nodes = None
+        qk_nodes_1 = self.model.match_parent_path(
             matmul_qkv,
             ["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
             [0, 0, 0, None, 0],
             return_indice=add_mask_indices,
         )
-        if qk_nodes is None:
+        qk_nodes_2 = self.model.match_parent_path(
+            matmul_qkv,
+            ["Softmax", "MatMul"],
+            [0, 0],
+        )
+        if qk_nodes_1 is not None:
+            qk_nodes = qk_nodes_1
+            assert len(add_mask_indices) == 1
+            causal_mask_input_index = 1 - add_mask_indices[0]
+
+            (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
+        elif qk_nodes_2 is not None:
+            qk_nodes = qk_nodes_2
+            (_softmax_qk, matmul_qk) = qk_nodes
+        else:
             logger.debug("fuse_attention: failed to match qk path")
             return
-        assert len(add_mask_indices) == 1
-        causal_mask_input_index = 1 - add_mask_indices[0]
-
-        (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
 
         q_nodes = self.model.match_parent_path(
             matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
@@ -172,23 +210,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
 
         attention_last_node = reshape_qkv
 
-        # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
-        # of computing causal mask.
-        causal_mask_nodes = self.model.match_parent_path(
-            add_mask,
-            ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-            [causal_mask_input_index, 0, 0, 0, 0, 0],
-        )
-        if causal_mask_nodes is None:
-            # If the model is exported with batch_size == 1, there is no Concat node
-            causal_mask_nodes = self.model.match_parent_path(
-                add_mask,
-                ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-                [causal_mask_input_index, 0, 0, 0, 0],
-            )
-            if causal_mask_nodes is None:
-                logger.debug("fuse_attention: failed to match causal mask subgraph")
-                return
+        if add_mask is not None:
+            # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
+            # of computing causal mask.
+            causal_mask_nodes = self.model.match_parent_path(
+                add_mask,
+                ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                [causal_mask_input_index, 0, 0, 0, 0, 0],
+            )
+            if causal_mask_nodes is None:
+                # If the model is exported with batch_size == 1, there is no Concat node
+                causal_mask_nodes = self.model.match_parent_path(
+                    add_mask,
+                    ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                    [causal_mask_input_index, 0, 0, 0, 0],
+                )
+                if causal_mask_nodes is None:
+                    logger.debug("fuse_attention: failed to match causal mask subgraph")
+                    return
 
         new_node = self.create_attention_node(
             mask_index=None,
@@ -204,7 +243,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             output=attention_last_node.output[0],
             add_qk_str=None,
             scale=None,
-            causal=True,
+            causal=(add_mask is not None),
         )
         if new_node is None:
             return
```
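
The net effect of this diff: the matcher now accepts both the masked and the unmasked score path, and `causal` is set only when a mask `Add` was found. An illustrative sketch (hypothetical tensor names, not onnxruntime code) of the two exported patterns:

```python
# Sketch of the two attention-score patterns the updated matcher distinguishes.
# Tensor names/shapes are hypothetical; this is not onnxruntime code.
import torch

def scores_text(q, k, causal_mask):
    # CLIP text encoder export: a causal mask Add sits between the QK^T MatMul
    # and the Softmax (matched as ["Softmax", "Reshape", "Add", "Reshape", "MatMul"]),
    # so the fused Attention node is created with causal=True.
    return torch.softmax(q @ k.transpose(-1, -2) + causal_mask, dim=-1)

def scores_vision(q, k):
    # CLIP vision encoder export: no mask Add between MatMul and Softmax
    # (matched as ["Softmax", "MatMul"]), so the node gets causal=False.
    return torch.softmax(q @ k.transpose(-1, -2), dim=-1)
```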
33 changes: 23 additions & 10 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
```diff
@@ -38,6 +38,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
              |                      |
              +----------------------+
         """
+        subgraph_nodes = []
         children = self.model.get_children(node, input_name_to_nodes)
         if len(children) == 0 or len(children) > 2:
             return
@@ -53,20 +54,24 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
 
         div_node = None
         for child in children:
-            div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
-            if div_node is not None:
-                break
+            # Check if Sub --> Div exists
+            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
+
+            # Check if Sub --> Cast --> Div
+            div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])
+
+            if div_node_1 is not None:
+                div_node = div_node_1
+            elif div_node_2 is not None:
+                div_node = div_node_2[-1]
         if div_node is None:
             return
 
         path_id, parent_nodes, _ = self.model.match_parent_paths(
             div_node,
             [
                 (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
-                (
-                    ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
-                    [1, 0, 0, 0, 0, 0],
-                ),
+                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
             ],
             output_name_to_node,
         )
@@ -87,15 +92,22 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if self.model.find_constant_input(pow_node, 2.0) != 1:
             return
 
-        mul_node = input_name_to_nodes[div_node.output[0]][0]
+        temp_node = input_name_to_nodes[div_node.output[0]][0]
+        if temp_node.op_type == "Cast":
+            # Div --> Cast --> Mul
+            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
+            mul_node = input_name_to_nodes[temp_node.output[0]][0]
+        else:
+            # Div --> Mul
+            mul_node = temp_node
         if mul_node.op_type != "Mul":
             return
 
         last_add_node = input_name_to_nodes[mul_node.output[0]][0]
         if last_add_node.op_type != "Add":
             return
 
-        subgraph_nodes = [node]
+        subgraph_nodes.append(node)
         subgraph_nodes.extend(children)
         subgraph_nodes.extend(parent_nodes[:-1])
@@ -109,7 +121,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
             logger.debug("It is not safe to fuse LayerNormalization node. Skip")
             return
 
-        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
+        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
+        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
         if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
             return
```
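
For reference, the pattern being fused now tolerates `Cast` nodes on either side of the `Div`, as produced by mixed-precision exports. A sketch of the decomposed LayerNormalization computation (illustrative only, not onnxruntime code; the `Cast` placements mirror the comments in the diff):

```python
# Decomposed LayerNorm as it appears in exported graphs, with the optional
# Casts this fusion now matches. Illustrative only; not onnxruntime code.
import torch

def decomposed_layernorm(x, weight, bias, eps=1e-5):
    u = x - x.mean(dim=-1, keepdim=True)      # ReduceMean --> Sub
    u = u.float()                             # optional Sub --> Cast --> Div
    var = (u * u).mean(dim=-1, keepdim=True)  # Pow --> ReduceMean
    y = u / torch.sqrt(var + eps)             # Add --> Sqrt --> Div
    y = y.to(x.dtype)                         # optional Div --> Cast --> Mul
    return y * weight + bias                  # Mul --> Add
```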
74 changes: 74 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_quickgelu.py
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import logging

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionQuickGelu(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Fuse the following subgraph to `QuickGelu`
        #
        #        root_input
        #         /      \
        #        |       Mul ----------+
        #        |   (B = ~1.702)      |
        #         \       |            |
        #          \   Sigmoid         |---- `QuickGelu`
        #           \     /            |
        #            \   /             |
        #             Mul -------------+
        #              |
        #         root_output

        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        approximation_value = self.model.get_constant_value(first_mul_node.input[1]).item()
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
```
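
For context, QuickGelu is the sigmoid-based GELU approximation that CLIP uses in its MLP blocks, which is why the matcher checks for a constant near 1.702. A small illustrative comparison against exact GELU:

```python
# Illustrative check of how closely QuickGelu(x) = x * sigmoid(1.702 * x)
# tracks exact GELU; this is the Mul/Sigmoid/Mul subgraph fused above.
import math

def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def quickgelu(x, alpha=1.702):
    # x * sigmoid(alpha * x), written without an explicit sigmoid helper
    return x / (1.0 + math.exp(-alpha * x))

for v in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={v:+.1f}  gelu={gelu_exact(v):+.4f}  quickgelu={quickgelu(v):+.4f}")
```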
3 changes: 3 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_bert.py
```diff
@@ -21,6 +21,7 @@
 from fusion_qordered_gelu import FusionQOrderedGelu
 from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
 from fusion_qordered_matmul import FusionQOrderedMatMul
+from fusion_quickgelu import FusionQuickGelu
 from fusion_reshape import FusionReshape
 from fusion_rotary_attention import FusionRotaryEmbeddings
 from fusion_shape import FusionShape
@@ -65,6 +66,8 @@ def fuse_gelu(self):
         fusion.apply()
         fusion = FusionFastGelu(self)
         fusion.apply()
+        fusion = FusionQuickGelu(self)
+        fusion.apply()
         # Only relevant in models with Q-DQ nodes
         fusion = FusionQOrderedGelu(self)
         fusion.apply()
```
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/onnx_model_clip.py
```diff
@@ -25,6 +25,7 @@ def get_fused_operator_statistics(self):
         ops = [
             "Attention",
             "LayerNormalization",
+            "QuickGelu",
             "SkipLayerNormalization",
         ]
         for op in ops:
```
6 changes: 6 additions & 0 deletions onnxruntime/test/python/transformers/test_gelu_fusions.py
```diff
@@ -21,6 +21,11 @@ def forward(self, x):
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
 
 
+class HuggingfaceQuickGelu(torch.nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+
 class MegatronGelu(torch.nn.Module):
     def forward(self, x):
         # The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx.
@@ -36,6 +41,7 @@ def forward(self, x):
 test_cases = [
     ("huggingface", "Gelu", HuggingfaceGelu),
     ("huggingface", "FastGelu", HuggingfaceFastGelu),
+    ("huggingface", "QuickGelu", HuggingfaceQuickGelu),
     ("megatron", "Gelu", MegatronGelu),
     ("megatron", "FastGelu", MegatronFastGelu),
 ]
```
