Skip to content

Commit

Permalink
Cherry pick 3 fixes to rel-1.2.0 (#3158)
Browse files Browse the repository at this point in the history
* Publish release symbols (#3152)

* Publish release symbols

* Publish symbols if IsReleaseBuild

* Disable delayload for cuda dlls (#3147)

This change fixes #3129. When running onnxruntime as dll on Windows, CUDA does some internal cleanups when process exits. After this, any call to CUDA would cause crash. Delayload makes thread_local destructor to happen after CUDA cleanup, thus the crash.

* Update Gelu Fusion to support new graph pattern from PyTorch 1.4 (#3148)

* update GeluFusion to support pattern from PyTorch 1.4; 
* Fix a bug that missing the check of an edge between mul2 and root.
* update script to fuse gelu from PyTorch 1.4
* Add test for python optimizer

Co-authored-by: Tiago Koji Castro Shibata <ticastro@microsoft.com>
Co-authored-by: KeDengMS <kedeng@microsoft.com>
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
  • Loading branch information
4 people committed Mar 7, 2020
1 parent b71554c commit dacb42f
Show file tree
Hide file tree
Showing 18 changed files with 372 additions and 62 deletions.
13 changes: 7 additions & 6 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -724,12 +724,13 @@ if (onnxruntime_USE_CUDA)
if (WIN32)
link_directories(${onnxruntime_CUDNN_HOME}/lib/x64)

file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
foreach(cuda_dll_path ${cuda_dll_paths})
get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
endforeach(cuda_dll_path)
# delayload causes crash on exit, so disable for now
#file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
#set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
#foreach(cuda_dll_path ${cuda_dll_paths})
# get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
# set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
#endforeach(cuda_dll_path)

else()
link_directories(${onnxruntime_CUDNN_HOME}/lib64)
Expand Down
112 changes: 78 additions & 34 deletions onnxruntime/core/optimizer/gelu_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,25 @@ static bool IsSupportedDataType(const Node& node) {
}
return true;
}

/*
This function fuses subgraph like the following into one Gelu node.
Subgraph pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul ==>
(B=1.4142...) (1)
Subgraph pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul ==>
(B=1.4142...) (1) (0.5)
After Fusion:
[root]--> Gelu ==>
*/
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Expand Down Expand Up @@ -68,13 +86,9 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
}

// Check the other input node(e.g. not of type Erf) is 1.0f.
const Node& add_first_input_node = *(add_node.InputNodesBegin());
int add_const_input_index = 0;
if (add_first_input_node.OpType().compare("Erf") == 0) {
add_const_input_index = 1;
}
const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index];
// Check the other input node (e.g. not the Erf) is 1.0f.
bool is_erf_first_input = (add_node.InputDefs()[0]->Name() == erf_node.MutableOutputDefs()[0]->Name());
const auto& add_const_input_arg = add_node.InputDefs()[is_erf_first_input ? 1 : 0];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *add_const_input_arg, 1.0f, true)) {
continue;
}
Expand All @@ -87,35 +101,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
}

const Node* p_mul2_node = nullptr;
for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
if ((*iter).OpType().compare("Mul") == 0) {
// find the other input node of Mul
p_mul2_node = &(*iter);
break;
bool is_pattern_1 = true;
const Node* p_mul2_node = graph_utils::FirstParentByType(mul_node, "Mul");
if (p_mul2_node != nullptr) {
// Match subgraph pattern 1
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul2_node)) {
continue;
}
}
if (p_mul2_node == nullptr) {
continue;
}

Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul2_node)) {
continue;
}
// One input of mul2_node shall be the subgraph input
auto root_index = optimizer_utils::IndexOfNodeInput(*p_mul2_node, *div.InputDefs()[0]);
if (root_index < 0)
continue;

// Check the other input node(e.g. not of type Add) is 0.5f.
int mul_const_input_index = 0;
if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
mul_const_input_index = 1;
}
// Check the other input node is 0.5f.
int mul_const_input_index = (root_index == 0 ? 1 : 0);

const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
}
} else {
is_pattern_1 = false;

// Match subgraph pattern 2
if (mul_node.GetOutputEdgesCount() != 1) {
continue;
}

// Another input of Mul node shall be the subgraph input.
auto root_index = optimizer_utils::IndexOfNodeInput(mul_node, *div.InputDefs()[0]);
if (root_index < 0)
continue;

Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;
}

int mul_const_input_index = 0;
if (mul2_node.InputDefs()[0]->Name() == mul_node.MutableOutputDefs()[0]->Name()) {
mul_const_input_index = 1;
}
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
}

p_mul2_node = &mul2_node;
}

const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
Expand All @@ -131,7 +170,12 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
// move input edges to div (first in list) across to the gelu_node.
// move output definitions and output edges from mul_node (last in list) to gelu_node.
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
Node& mul2 = *graph.GetNode(p_mul2_node->Index());
if (is_pattern_1) {
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2, mul_node}, gelu_node);
} else {
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul_node, mul2}, gelu_node);
}

modified = true;
}
Expand Down
48 changes: 35 additions & 13 deletions onnxruntime/python/tools/bert/BertOnnxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,24 @@ def fuse_gelu(self, gelu_op_name):

"""
Fuse Gelu with Erf into one node:
+-------Mul(B=0.5)-------------------+
Pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (B=1)
(B=1.4142...) (1)
Pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
def fuse_gelu_with_elf(self, gelu_op_name):
logger.debug(f"start fuse_gelu_with_elf({gelu_op_name})")
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()

Expand Down Expand Up @@ -276,25 +285,38 @@ def fuse_gelu_with_elf(self, gelu_op_name):
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
continue

root_node = self.get_parent(div, 0, output_name_to_node)
if root_node is None:
continue
subgraph_input = div.input[0]

mul_half = self.match_parent(mul_after_erf, 'Mul', None, output_name_to_node)
if mul_half is None:
continue
another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
if subgraph_input == mul_after_erf.input[another]: # pattern 2
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != 'Mul':
continue
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
continue
subgraph_output = mul_half.output[0]
else: # pattern 1
mul_half = self.match_parent(mul_after_erf, 'Mul', another, output_name_to_node)
if mul_half is None:
continue

if not self.has_constant_input(mul_half, 0.5):
continue
if not self.has_constant_input(mul_half, 0.5):
continue

if subgraph_input not in mul_half.input:
continue

subgraph_output = mul_after_erf.output[0]

subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul_after_erf.output[0]], input_name_to_nodes, output_name_to_node):
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
continue

nodes_to_remove.extend(subgraph_nodes)
gelu_node = onnx.helper.make_node(gelu_op_name,
inputs=[root_node.output[0]],
outputs=[mul_after_erf.output[0]])
inputs=[subgraph_input],
outputs=[subgraph_output])
gelu_node.domain = "com.microsoft"
nodes_to_add.append(gelu_node)

Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/python/tools/bert/test_bert_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
BERT_TEST_MODELS = {
"bert_pytorch_0": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_0.onnx',
"bert_pytorch_1": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_1.onnx',
"bert_squad_pytorch1.4_opset10_fp32": 'test_data\\bert_squad_pytorch1.4_opset10_fp32\\BertForQuestionAnswering.onnx',
"bert_keras_0": 'test_data\\bert_mrpc_tensorflow2.1_opset10\\TFBertForSequenceClassification_1.onnx'
}

Expand Down Expand Up @@ -155,6 +156,13 @@ def test_pytorch_model_0_gpu(self):
}
self.verify_node_count(bert_model, expected_node_count)

def test_pytorch_model_2_cpu(self):
input = BERT_TEST_MODELS['bert_squad_pytorch1.4_opset10_fp32']
bert_model = optimize_model(input, 'bert', gpu_only=False,
num_heads=2, hidden_size=8, sequence_length=10,
input_int32=False, float16=False)
self.assertTrue(bert_model.is_fully_optimized())

def test_keras_model_1_cpu(self):
input = BERT_TEST_MODELS['bert_keras_0']

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

BstartJ(���A2I<W�1<nR�<G�;�^<�q?;���<
r�;��<
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

BendJ(���=Fڞ=L��=QR�=�w�=6\�=��=���=��=Jg�=
57 changes: 57 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,63 @@ TEST(GraphTransformationTests, GeluFusionTest) {
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestFormat2) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, BiasGeluTest) {
auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/shared_lib/test_model_loading.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

#include "core/session/onnxruntime_cxx_api.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
#include <fstream>
#include "test_fixture.h"
#include "file_util.h"
Expand All @@ -25,6 +28,12 @@ TEST(CApiTest, model_from_array) {

Ort::SessionOptions so;
Ort::Session session(*ort_env.get(), buffer.data(), buffer.size(), so);

#ifdef USE_CUDA
// test with CUDA provider when using onnxruntime as dll
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0));
Ort::Session session_cuda(*ort_env.get(), buffer.data(), buffer.size(), so);
#endif
}
} // namespace test
} // namespace onnxruntime
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit dacb42f

Please sign in to comment.