Skip to content

Commit

Permalink
Fix view with size (openvinotoolkit#932)
Browse files Browse the repository at this point in the history
* Fix view with size in BERT

* Renamings

* Fix equalities
  • Loading branch information
vshampor committed Sep 22, 2021
1 parent 4e3d61a commit 818971a
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions nncf/torch/dynamic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections import Counter
from typing import Dict
from typing import List
from typing import Optional
Expand Down Expand Up @@ -70,19 +70,34 @@ def __call__(self, node_inputs: List[TensorMeta], real_inputs: List[TensorMeta],


class DefaultInputsMatcher(InputsMatcher):
def __call__(self, node_inputs: List[TensorMeta], real_inputs: List[TensorMeta],
def __call__(self, saved_inputs: List[TensorMeta], actual_inputs: List[TensorMeta],
tm_comparators: List[TensorMetaComparator]) -> bool:
if node_inputs is None and real_inputs:
if saved_inputs is None and actual_inputs:
return False

for saved_input, actual_input in zip(node_inputs, real_inputs):
matched_with_unexpected_tensors = False
for saved_input, actual_input in zip(saved_inputs, actual_inputs):
if saved_input is None and actual_input is None:
continue
if (saved_input is None) != (actual_input is None):
if (saved_input is None) and (actual_input is not None):
# torch.Tensor.size() seems to return ints when not tracing ONNX
# and tensors when tracing ONNX. This breaks input-based node matching whenever
# torch.Tensor.size() return value is passed into a NNCF-traced operation (such as `view`)
# because at graph building time it expected to see ints as args and now it sees tensors.
# To mitigate this, will only match inputs against the positions which had tensors during build-time
# and disregard the rest of the argument positions.
matched_with_unexpected_tensors = True
continue
if (saved_input is not None) and (actual_input is None):
return False
for tm_comparator in tm_comparators:
if not tm_comparator(saved_input, actual_input):
return False
if matched_with_unexpected_tensors:
nncf_logger.debug("Had to match a node to an op which has tensors at positions where there were no tensors "
"at graph building time:\nNode input metas: {}, but op input metas: {}".format(
saved_inputs, actual_inputs
))
return True


Expand All @@ -109,9 +124,14 @@ def __init__(self,
DefaultTensorMetaComparator()]
self.input_matcher = input_matcher if input_matcher else DefaultInputsMatcher()

def __eq__(self, other: 'OperationExecutionContext'):
return (self.op_address == other.op_address) and \
self.input_matcher(self.tensor_metas, other.tensor_metas, self.tm_comparators)
def __eq__(self, other):
return self.op_address == other.op_address and Counter(self.tensor_metas) == Counter(other.tensor_metas)

def matches_saved_inputs_from(self, other: 'OperationExecutionContext'):
# WARNING: not commutative
return self.op_address == other.op_address and self.input_matcher(other.tensor_metas,
self.tensor_metas,
self.tm_comparators)

def __hash__(self):
return hash((self.operator_name, tuple(self.scope_in_model), self.call_order,
Expand Down Expand Up @@ -181,7 +201,7 @@ def _find_nodes_with_matching_context_among_inputless(self, op_exec_context: Ope
-> Dict[str, DynamicGraphNode]:
node_candidates = {}
for nx_node_key, node in self._inputless_nodes.items():
if node.op_exec_context == op_exec_context:
if op_exec_context.matches_saved_inputs_from(node.op_exec_context):
node_candidates[nx_node_key] = node
return node_candidates

Expand All @@ -194,7 +214,7 @@ def _find_nodes_with_matching_context_and_inputs(self, op_exec_context: Operatio
creator_id = info.creator_id
for successor_node_key in self._nx_graph.successors(self._node_id_to_key_dict[creator_id]):
successor_node = self._nx_graph.nodes[successor_node_key]
if op_exec_context == successor_node[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]:
if op_exec_context.matches_saved_inputs_from(successor_node[DynamicGraph.OP_EXEC_CONTEXT_NODE_ATTR]):
nx_node_candidates[successor_node_key] = successor_node

node_candidates = {} # type: Dict[str, DynamicGraphNode]
Expand Down Expand Up @@ -392,7 +412,7 @@ def _match_first_iteration_nodes(self, op_exec_context: OperationExecutionContex
for iter_scope in iter_scopes:
if iter_scope in self._first_iteration_nodes:
for name, node in self._first_iteration_nodes[iter_scope].items():
if op_exec_context == node.op_exec_context:
if op_exec_context.matches_saved_inputs_from(node.op_exec_context):
node_candidates[name] = node
break
if node_candidates:
Expand Down

0 comments on commit 818971a

Please sign in to comment.