[WebNN EP] Add data type constraint (microsoft#20779)
The WebNN spec has added data type constraints for every op, and its CPU
backend (currently TFLite) imposes additional constraints. Add the
corresponding constraints to each op in the WebNN EP.

Note: Temporarily disable fp16 for the CPU backend; its fp16 support is
planned to be ready in Chromium next month.
Honry committed May 29, 2024
1 parent e77f238 commit 9ea9f9e
Showing 25 changed files with 680 additions and 76 deletions.
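
The recurring shape of the change, for readers skimming the hunks below: each op builder starts from the set of data types the WebNN spec allows for that op, then erases the types the CPU (TFLite) backend cannot yet handle. A standalone sketch of the pattern, using illustrative stand-ins rather than the EP's actual identifiers:

```cpp
#include <cstdint>
#include <unordered_set>

// Illustrative stand-ins for ONNX TensorProto_DataType and the EP's device type.
enum DataType : int32_t { kFloat32, kFloat16, kInt8, kInt32 };
enum class DeviceType { CPU, GPU };

bool IsReluInputSupported(DataType input_type, DeviceType device) {
  // Spec-level constraint: WebNN relu accepts these four types.
  std::unordered_set<int32_t> supported = {kFloat32, kFloat16, kInt8, kInt32};
  // Backend-level constraint: the CPU (TFLite) backend lacks int32 relu.
  if (device == DeviceType::CPU) {
    supported.erase(kInt32);
  }
  return supported.count(input_type) > 0;
}
```

Each op builder in the diff overrides HasSupportedInputsImpl with exactly this shape.
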
14 changes: 4 additions & 10 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -130,16 +130,10 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
   return supported_node_groups;
 }
 
-bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type) {
-  // Current data type implementation status of WebNN is inconsistent along with different backends,
-  // The XNNPack backend supports only FP32, while the DML backend POC supports more.
-  if (device_type == WebnnDeviceType::CPU) {
-    return std::find(supported_cpu_data_types.begin(), supported_cpu_data_types.end(), data_type) !=
-           supported_cpu_data_types.end();
-  } else {
-    return std::find(supported_gpu_data_types.begin(), supported_gpu_data_types.end(), data_type) !=
-           supported_gpu_data_types.end();
-  }
+bool IsSupportedDataType(const int32_t data_type,
+                         const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types) {
+  return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) !=
+         supported_data_types.end();
 }
 
 bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
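
With the new signature, the caller supplies the exact type set, so the helper no longer needs to know about devices; the device-specific narrowing moves into each op builder. A minimal usage sketch (the include shown is an assumption; any header that defines ONNX_NAMESPACE::TensorProto_DataType works):

```cpp
#include <unordered_set>

#include "core/graph/onnx_protobuf.h"  // assumed path to the TensorProto_DataType enum

// Hypothetical caller: an op that accepts only floating-point inputs.
const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> float_types = {
    ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
    ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
};

// INT64 is not in the caller-supplied set, so this returns false.
const bool ok = IsSupportedDataType(ONNX_NAMESPACE::TensorProto_DataType_INT64, float_types);
```
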
9 changes: 3 additions & 6 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -262,11 +262,7 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn
   return true;
 }
 
-constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 1> supported_cpu_data_types = {
-    ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
-};
-
-constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data_types = {
+static const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> webnn_supported_data_types = {
     ONNX_NAMESPACE::TensorProto_DataType_BOOL,
     ONNX_NAMESPACE::TensorProto_DataType_INT8,
     ONNX_NAMESPACE::TensorProto_DataType_UINT8,
@@ -278,7 +274,8 @@ constexpr std::array<ONNX_NAMESPACE::TensorProto_DataType, 9> supported_gpu_data
     ONNX_NAMESPACE::TensorProto_DataType_UINT64,
 };
 
-bool IsSupportedDataType(const int32_t data_type, const WebnnDeviceType device_type);
+bool IsSupportedDataType(const int32_t data_type,
+                         const std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType>& supported_data_types);
 
 bool IsValidMultidirectionalBroadcast(std::vector<int64_t>& shape_a,
                                       std::vector<int64_t>& shape_b,
40 changes: 40 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc
@@ -21,6 +21,8 @@ class ActivationOpBuilder : public BaseOpBuilder {
   // Operator support related.
   bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                          WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                              const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -81,6 +83,44 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi
   return true;
 }
 
+bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                                                 const logging::Logger& logger) const {
+  const auto& input = *node.InputDefs()[0];
+  const auto& op_type = node.OpType();
+  int32_t input_type;
+  if (!GetType(input, input_type, logger))
+    return false;
+
+  std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
+  // WebNN relu op supports float32, float16, int32, int8 input data types.
+  if (op_type == "Relu") {
+    supported_data_types = {
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
+        ONNX_NAMESPACE::TensorProto_DataType_INT32,
+        ONNX_NAMESPACE::TensorProto_DataType_INT8,
+    };
+    // WebNN CPU backend does not support int32 data type for relu.
+    if (device_type == WebnnDeviceType::CPU) {
+      supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
+    }
+  } else {  // Others only support float32 and float16.
+    supported_data_types = {
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
+    };
+  }
+
+  if (!IsSupportedDataType(input_type, supported_data_types)) {
+    LOGS(logger, VERBOSE) << "[" << op_type
+                          << "] Input type: [" << input_type
+                          << "] is not supported for now";
+    return false;
+  }
+
+  return true;
+}
+
 void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   if (op_registrations.op_builder_map.count(op_type) > 0)
     return;
27 changes: 27 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
@@ -22,6 +22,8 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
   // Operator support related.
   bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                          WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                              const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -77,6 +79,31 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
   return true;
 }
 
+bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                                                const logging::Logger& logger) const {
+  const auto& input = *node.InputDefs()[0];
+  const auto& op_type = node.OpType();
+  int32_t input_type;
+  if (!GetType(input, input_type, logger))
+    return false;
+
+  std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
+  // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin.
+  if (device_type == WebnnDeviceType::CPU) {
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
+  }
+
+  if (!IsSupportedDataType(input_type, supported_data_types)) {
+    LOGS(logger, VERBOSE) << "[" << op_type
+                          << "] Input type: [" << input_type
+                          << "] is not supported for now";
+    return false;
+  }
+
+  return true;
+}
+
 void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   if (op_registrations.op_builder_map.count(op_type) > 0)
     return;
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -73,11 +73,23 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d
     }
   }
 
+  // WebNN CPU backend (TFLite) will enable float16 input data type soon,
+  // temporarily fallback float16 input data type for WebNN CPU.
+  if (device_type == WebnnDeviceType::CPU) {
+    const auto& input = *node.InputDefs()[0];
+
+    int32_t input_type;
+    if (!GetType(input, input_type, logger))
+      return false;
+    if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
+      return false;
+  }
+
   return HasSupportedInputsImpl(node, device_type, logger);
 }
 
 bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
-                                           const WebnnDeviceType device_type,
+                                           const WebnnDeviceType /* device_type */,
                                            const logging::Logger& logger) const {
   // We only check the type of input 0 by default, specific op builder can override this.
   const auto& input = *node.InputDefs()[0];
@@ -86,7 +98,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
   if (!GetType(input, input_type, logger))
     return false;
 
-  if (!IsSupportedDataType(input_type, device_type)) {
+  if (!IsSupportedDataType(input_type, webnn_supported_data_types)) {
     LOGS(logger, VERBOSE) << "[" << node.OpType()
                           << "] Input type: [" << input_type
                           << "] is not supported for now";
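
The temporary fp16 fallback lives in the shared HasSupportedInputs path, so float16 inputs are rejected for the CPU backend before any per-op HasSupportedInputsImpl runs. A simplified standalone mirror of that ordering (not the EP's actual classes):

```cpp
#include <cstdint>

enum class DeviceType { CPU, GPU };

constexpr int32_t kFloat16 = 10;  // TensorProto_DataType_FLOAT16 in onnx.proto

// Simplified mirror of BaseOpBuilder::HasSupportedInputs above.
bool HasSupportedInputs(int32_t input0_type, DeviceType device) {
  // Temporary: the CPU (TFLite) backend cannot take fp16 inputs yet,
  // so bail out before consulting the op-specific check.
  if (device == DeviceType::CPU && input0_type == kFloat16) {
    return false;
  }
  return true;  // the real code delegates to HasSupportedInputsImpl here.
}
```
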
45 changes: 45 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc
@@ -22,6 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
   // Operator support related.
   bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                          const WebnnDeviceType device_type, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                              const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -72,6 +74,49 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
   return true;
 }
 
+bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                                             const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  const auto& op_type = node.OpType();
+  int32_t input0_type;
+  int32_t input1_type;
+
+  if (!GetType(*input_defs[0], input0_type, logger) ||
+      !GetType(*input_defs[1], input1_type, logger))
+    return false;
+
+  std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types;
+  // WebNN prelu op only supports float32, float16, int32, int8 input data types.
+  if (op_type == "Prelu") {
+    supported_data_types = {
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
+        ONNX_NAMESPACE::TensorProto_DataType_FLOAT16,
+        ONNX_NAMESPACE::TensorProto_DataType_INT32,
+        ONNX_NAMESPACE::TensorProto_DataType_INT8,
+    };
+    // WebNN CPU backend doesn't support int32 for prelu.
+    if (device_type == WebnnDeviceType::CPU) {
+      supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
+    }
+  } else {
+    supported_data_types = webnn_supported_data_types;
+  }
+  if (!IsSupportedDataType(input0_type, supported_data_types)) {
+    LOGS(logger, VERBOSE) << "[" << op_type
+                          << "] Input type: [" << input0_type
+                          << "] is not supported for now";
+    return false;
+  }
+
+  if (input0_type != input1_type) {
+    LOGS(logger, VERBOSE) << "[" << op_type
+                          << "] Input data types should be the same.";
+    return false;
+  }
+
+  return true;
+}
+
 void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   if (op_registrations.op_builder_map.count(op_type) > 0)
     return;
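
Beyond the per-type allow-list, binary ops require both inputs to share one data type; only input 0 is checked against the set, and input 1 just has to match it. A compact illustrative version, with plain int32_t values standing in for TensorProto_DataType:

```cpp
#include <cstdint>
#include <unordered_set>

bool BinaryInputsSupported(int32_t input0_type, int32_t input1_type,
                           const std::unordered_set<int32_t>& supported) {
  // Same order as the diff: allow-list check first, then the equality check.
  if (supported.count(input0_type) == 0) {
    return false;
  }
  return input0_type == input1_type;
}

// Example: FLOAT (1) paired with FLOAT16 (10) is rejected even though both
// types are individually supported: BinaryInputsSupported(1, 10, {1, 10}) == false.
```
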
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc
@@ -22,7 +22,7 @@ class CastOpBuilder : public BaseOpBuilder {
   // Operator support related.
  private:
   bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                         const WebnnDeviceType device_type, const logging::Logger& logger) const override;
+                         const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -80,12 +80,12 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
 
 bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
                                       const Node& node,
-                                      const WebnnDeviceType device_type,
+                                      const WebnnDeviceType /* device_type */,
                                       const logging::Logger& logger) const {
   NodeAttrHelper helper(node);
   // Check cast output type.
   const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
-  if (!IsSupportedDataType(to_type, device_type)) {
+  if (!IsSupportedDataType(to_type, webnn_supported_data_types)) {
     LOGS(logger, VERBOSE) << "Invalid cast to type " << to_type << ".";
     return false;
   }
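
Cast is the one op here gated on its output rather than its input: the `to` attribute names the target type, which the diff validates against the common webnn_supported_data_types set. A standalone sketch of that check, with plain values standing in for the ONNX enum:

```cpp
#include <cstdint>
#include <unordered_set>

bool CastTargetSupported(int32_t to_type,
                         const std::unordered_set<int32_t>& webnn_types) {
  // A missing "to" attribute defaults to UNDEFINED (0 in onnx.proto); since 0
  // is never in the supported set, it is rejected implicitly.
  return webnn_types.count(to_type) > 0;
}

// Example with onnx.proto values FLOAT = 1 and FLOAT16 = 10:
// CastTargetSupported(10, {1, 10}) == true.
```
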
29 changes: 29 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc
@@ -25,6 +25,8 @@ class ClipOpBuilder : public BaseOpBuilder {
  private:
   bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                          const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                              const logging::Logger& logger) const override;
 };
 
 // Add operator related.
@@ -71,6 +73,33 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
   return GetClipMinMax(initializers, node, min, max, logger);
 }
 
+bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type,
+                                           const logging::Logger& logger) const {
+  const auto& input = *node.InputDefs()[0];
+  const auto& op_type = node.OpType();
+  int32_t input_type;
+  if (!GetType(input, input_type, logger))
+    return false;
+
+  std::unordered_set<ONNX_NAMESPACE::TensorProto_DataType> supported_data_types = webnn_supported_data_types;
+  // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp.
+  if (device_type == WebnnDeviceType::CPU) {
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32);
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32);
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+    supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64);
+  }
+
+  if (!IsSupportedDataType(input_type, supported_data_types)) {
+    LOGS(logger, VERBOSE) << "[" << op_type
+                          << "] Input type: [" << input_type
+                          << "] is not supported for now";
+    return false;
+  }
+
+  return true;
+}
+
 void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   op_registrations.builders.push_back(std::make_unique<ClipOpBuilder>());
   op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
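
Clip repeats the copy-then-erase idiom from ArgMax/ArgMin with a longer exclusion list. Purely for illustration, the idiom could be factored into a small utility; this helper is not part of the PR and the name is invented:

```cpp
#include <cstdint>
#include <initializer_list>
#include <unordered_set>

// Hypothetical utility: a copy of the common set minus backend exclusions.
std::unordered_set<int32_t> WithoutTypes(std::unordered_set<int32_t> base,
                                         std::initializer_list<int32_t> drop) {
  for (int32_t t : drop) {
    base.erase(t);
  }
  return base;
}

// Usage sketch for clamp on the CPU backend, with onnx.proto values
// INT32 = 6, INT64 = 7, UINT32 = 12, UINT64 = 13:
// auto cpu_clamp_types = WithoutTypes(common_types, {6, 12, 7, 13});
```
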
