diff --git a/nncf/torch/dynamic_graph/graph.py b/nncf/torch/dynamic_graph/graph.py index a67bdec77f4..b46db1f654e 100644 --- a/nncf/torch/dynamic_graph/graph.py +++ b/nncf/torch/dynamic_graph/graph.py @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -from collections import Counter + from typing import Dict from typing import List from typing import Optional @@ -70,34 +70,19 @@ def __call__(self, node_inputs: List[TensorMeta], real_inputs: List[TensorMeta], class DefaultInputsMatcher(InputsMatcher): - def __call__(self, saved_inputs: List[TensorMeta], actual_inputs: List[TensorMeta], + def __call__(self, node_inputs: List[TensorMeta], real_inputs: List[TensorMeta], tm_comparators: List[TensorMetaComparator]) -> bool: - if saved_inputs is None and actual_inputs: + if node_inputs is None and real_inputs: return False - matched_with_unexpected_tensors = False - for saved_input, actual_input in zip(saved_inputs, actual_inputs): + for saved_input, actual_input in zip(node_inputs, real_inputs): if saved_input is None and actual_input is None: continue - if (saved_input is None) and (actual_input is not None): - # torch.Tensor.size() seems to return ints when not tracing ONNX - # and tensors when tracing ONNX. This breaks input-based node matching whenever - # torch.Tensor.size() return value is passed into a NNCF-traced operation (such as `view`) - # because at graph building time it expected to see ints as args and now it sees tensors. - # To mitigate this, will only match inputs against the positions which had tensors during build-time - # and disregard the rest of the argument positions. - matched_with_unexpected_tensors = True - continue - if (saved_input is not None) and (actual_input is None): + if (saved_input is None) != (actual_input is None): return False for tm_comparator in tm_comparators: if not tm_comparator(saved_input, actual_input): return False - if matched_with_unexpected_tensors: - nncf_logger.debug("Had to match a node to an op which has tensors at positions where there were no tensors " - "at graph building time:\nNode input metas: {}, but op input metas: {}".format( - saved_inputs, actual_inputs - )) return True @@ -124,14 +109,9 @@ def __init__(self, DefaultTensorMetaComparator()] self.input_matcher = input_matcher if input_matcher else DefaultInputsMatcher() - def __eq__(self, other): - return self.op_address == other.op_address and Counter(self.tensor_metas) == Counter(other.tensor_metas) - - def matches_saved_inputs_from(self, other: 'OperationExecutionContext'): - # WARNING: not commutative - return self.op_address == other.op_address and self.input_matcher(other.tensor_metas, - self.tensor_metas, - self.tm_comparators) + def __eq__(self, other: 'OperationExecutionContext'): + return (self.op_address == other.op_address) and \ + self.input_matcher(self.tensor_metas, other.tensor_metas, self.tm_comparators) def __hash__(self): return hash((self.operator_name, tuple(self.scope_in_model), self.call_order, @@ -201,7 +181,7 @@ def _find_nodes_with_matching_context_among_inputless(self, op_exec_context: Ope -> Dict[str, DynamicGraphNode]: node_candidates = {} for nx_node_key, node in self._inputless_nodes.items(): - if op_exec_context.matches_saved_inputs_from(node.op_exec_context): + if node.op_exec_context == op_exec_context: node_candidates[nx_node_key] = node return node_candidates @@ -214,7 +194,7 @@ def _find_nodes_with_matching_context_and_inputs(self, op_exec_context: Operatio creator_id = info.creator_id for successor_node_key in self._nx_graph.successors(self._node_id_to_key_dict[creator_id]): successor_node = self._nx_graph.nodes[successor_node_key] - if op_exec_context.matches_saved_inputs_from(successor_node[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]): + if op_exec_context == successor_node[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]: nx_node_candidates[successor_node_key] = successor_node node_candidates = {} # type: Dict[str, DynamicGraphNode] @@ -412,7 +392,7 @@ def _match_first_iteration_nodes(self, op_exec_context: OperationExecutionContex for iter_scope in iter_scopes: if iter_scope in self._first_iteration_nodes: for name, node in self._first_iteration_nodes[iter_scope].items(): - if op_exec_context.matches_saved_inputs_from(node.op_exec_context): + if op_exec_context == node.op_exec_context: node_candidates[name] = node break if node_candidates: