Skip to content

Commit

Permalink
A cleaner and more scalable implementation of symbolic tracing (huggi…
Browse files Browse the repository at this point in the history
…ngface#11763)

Cleaner and more scalable implementation of symbolic tracing with torch.fx, and provides support for new architectures:
- ALBERT
- DistilBERT
- MobileBERT
- MegatronBERT
- GPT2
- GPT Neo

Co-authored-by: Michael Benayoun <michael@huggingface.co>
  • Loading branch information
michaelbenayoun and michaelbenayoun committed May 20, 2021
1 parent 469384a commit f4a0d6f
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 114 deletions.
309 changes: 211 additions & 98 deletions src/transformers/modeling_fx_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
import dis
import copy
import functools
import inspect
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from torch.fx import GraphModule, Node, Proxy, Tracer

from . import PreTrainedModel
from torch.fx import Graph, GraphModule, Node, Proxy, Tracer
from torch.fx.node import Argument

from . import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
GPT2DoubleHeadsModel,
PreTrainedModel,
logging,
)
from .models.auto import get_values


logger = logging.get_logger(__name__)


class HFProxy(Proxy):
Expand All @@ -21,127 +41,214 @@ def __init__(self, node: Node, tracer: Optional[Tracer] = None):
self.device = self.tracer.root.device
self.dtype = next(self.tracer.root.parameters()).dtype

def dim(self):
return len(self.tracer.encoder_shape)

def _shape(self, calling_frame):
module = calling_frame.f_locals.get("self", None)
is_decoder = hasattr(module, "is_decoder") and module.is_decoder
return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape)

def size(self, dim=None):
frame = inspect.currentframe()
calling_frame = frame.f_back

# self.size can be called through the shape property, in which case we need to get the outer
# frame, containing the meaningful information.
if calling_frame.f_code.co_name == "shape":
calling_frame = calling_frame.f_back

instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti]))
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()

shape = self._shape(calling_frame)

if calling_frame.f_code.co_name == "transpose_for_scores":
# Provides the proper "x.size()" for:
# new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
shape = shape + [-1]
elif "context_layer" in calling_frame.f_locals:
# Provides the proper "context_layer.size()" for:
# new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
shape = shape + [-1, -1]
elif calling_frame.f_locals.get("do_cross_attention", False):
# Provides the proper shape for:
# query_length = present_key_value_state[0].shape[2]
# (modeling_t5.py)
shape = list(self.tracer.encoder_shape)
shape = shape[:1] + [-1] + shape[1:2]
elif "key_length" in code_context or "encoder_seq_length" in code_context:
shape = list(self.tracer.encoder_shape)
elif "lm_logits.size(-1)" in code_context:
shape = [self.tracer.root.config.vocab_size]
elif "start_positions" in code_context or "end_positions" in code_context:
# For question answering tasks.
shape = [1]
elif "num_choices" in code_context:
if self.tracer.num_choices <= 0:
raise ValueError("num_choices must be given to the CustomTracer for MultipleChoice tasks.")
shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
elif "hidden_states.s" in code_context:
shape = shape + [self.tracer.root.config.hidden_size]
else:
# Default case:
# - If self.size is called for an unpacking, retrieves the corresponding unpacking
# instruction, and returns the shape padded as much as necessary to match the expected
# number of items.
# - If self.size is called outside of an unpacking context, simply return the shape.
is_unpack = False

for inst in instructions:
if inst.opname == "UNPACK_SEQUENCE":
is_unpack = True
break

if is_unpack and inst.argval >= 3:
shape += [self.tracer.root.config.hidden_size]
dummy_values = [1] * (inst.argval - 3)
shape += dummy_values

if dim is not None:
return shape[dim]

return tuple(shape)

@property
def shape(self):
return self.size()

def __bool__(self) -> bool:
frame = inspect.currentframe()
calling_frame = frame.f_back
code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
if calling_frame.f_code.co_name == "apply_chunking_to_forward":
# Returning True to every assertion in "apply_chuncking_to_forward"
return True
elif "assert" in code_context:
# Returning True to any assertion.
return True
elif calling_frame.f_code.co_name == "get_extended_attention_mask":
# Corresponding to:
# if causal_mask.shape[1] < attention_mask.shape[1]:
return calling_frame.f_back.f_locals["past_key_values"][0] is not None
raise NotImplementedError("__bool__ was called for CustomProxy, but this case is not covered yet.")

def __setitem__(self, key, value):
pass

def __contains__(self, key):
return False


def _wrap_method_for_model_recording(model, method_name, cache_name):
"""Helper function that wraps a torch.Tensor method to record its outputs during forward pass."""
method = getattr(torch.Tensor, method_name)

@functools.wraps(method)
def wrapped(*args, **kwargs):
if not hasattr(model, cache_name):
setattr(model, cache_name, [])
cache = getattr(model, cache_name)
res = method(*args, **kwargs)
cache.append(res)
return res

return wrapped


def _create_recorded_proxy_method(proxy, method_name, cache_name):
"""
Helper function that sets a recorded torch.Tensor method as a HFProxy method that will use the recorded values
during symbolic tracing.
"""

def method(self, *args, **kwargs):
cache = getattr(self.tracer.root, cache_name)
res = cache.pop(0)
return res

method.__name__ = method_name
bound_method = method.__get__(proxy, proxy.__class__)
setattr(proxy, method_name, bound_method)


def _wrap_method_for_model_tracing(model, method_name, cache_name):
"""
Helper function that sets a recorded torch.Tensor method as a torch.Tensor method that will use the recorded values
during symbolic tracing.
"""

original_method = getattr(torch.Tensor, method_name)

@functools.wraps(original_method)
def method(*args, **kwargs):
cache = getattr(model, cache_name)
res = cache.pop(0)
return res

setattr(torch.Tensor, method_name, method)

if method_name == "size":
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))


def _monkey_patch_tensor_methods_for_model_recording(model, method_names):
"""
Helper function that patches torch.Tensor methods (specified by the method_names list) to record model inference
before symbolic tracing.
"""
cache_names = dict()
original_methods = dict()
for method_name in method_names:
cache_name = f"cache_{method_name}"
cache_names[method_name] = cache_name
if not hasattr(torch.Tensor, method_name):
logger.info(f"torch.Tensor has no method called {method_name}, skipping patching.")
continue
original_methods[method_name] = getattr(torch.Tensor, method_name)
setattr(torch.Tensor, method_name, _wrap_method_for_model_recording(model, method_name, cache_name))

if method_name == "size":
original_methods["shape"] = torch.Tensor.shape
setattr(torch.Tensor, "shape", property(getattr(torch.Tensor, method_name)))

return cache_names, original_methods


def _reset_tensor_methods(original_methods):
"""Helper function that resets the monkey patched torch.Tensor methods to their original values."""
for name, method in original_methods.items():
setattr(torch.Tensor, name, method)


class HFTracer(Tracer):
"""
Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it
uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
regular PyTorch torch.fx.Proxy.
"""

default_methods_to_record = {"__bool__", "size", "dim"}

def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
super().__init__()
encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1
decoder_sequence_length = (
sequence_length[1] if isinstance(sequence_length, (list, tuple)) else encoder_sequence_length
)
self.encoder_shape = [batch_size, encoder_sequence_length]
self.decoder_shape = (
[batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
)
self.num_choices = num_choices
if self.num_choices > 0:
self.encoder_shape[0] *= self.num_choices
self.encoder_shape = [batch_size, self.num_choices, encoder_sequence_length]
self.decoder_shape = [batch_size, self.num_choices, decoder_sequence_length]

self.prev_module = None
self.recorded_methods = None

def proxy(self, node: Node):
return HFProxy(node, self)
p = HFProxy(node, self)
if self.recorded_methods:
for method_name, cache_name in self.recorded_methods.items():
_create_recorded_proxy_method(p, method_name, cache_name)
return p

def _generate_dummy_input(self, model, input_name):
"""Generates dummy input for model inference recording."""
model_class = model.__class__
device = model.device
inputs_dict = dict()

if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = self.encoder_shape[0]
if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(batch_size, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [
*get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class in [
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*get_values(MODEL_FOR_MASKED_LM_MAPPING),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
GPT2DoubleHeadsModel,
]:
inputs_dict["labels"] = torch.zeros(self.decoder_shape, dtype=torch.long, device=device)
elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros(self.encoder_shape, dtype=torch.long, device=device)
else:
raise NotImplementedError(f"{model_class} not supported yet.")

elif "mask" in input_name or "ids" in input_name:
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
inputs_dict[input_name] = torch.ones(shape, dtype=torch.long, device=device)
else:
shape = self.encoder_shape if "decoder" not in input_name else self.decoder_shape
shape += [model.config.hidden_size]
inputs_dict[input_name] = torch.ones(shape, dtype=torch.float, device=device)

return inputs_dict

def record(self, model, input_names, method_names=None):
"""
Records torch.Tensor method outputs (specified by the method_names list) that will then be used during symbolic
tracing.
"""
if method_names is None:
method_names = self.default_methods_to_record

inputs = dict()
for input_name in input_names:
inputs.update(self._generate_dummy_input(model, input_name))

clone = copy.deepcopy(model)
cache_names, original_methods = _monkey_patch_tensor_methods_for_model_recording(clone, method_names)
self.original_methods = original_methods

clone(**inputs)

_reset_tensor_methods(original_methods)

self.recorded_methods = {
method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(clone, cache_name)
}

for cache_name in self.recorded_methods.values():
setattr(model, cache_name, getattr(clone, cache_name))

def trace(self, root: PreTrainedModel, concrete_args: Optional[Dict[str, Any]] = None, method_names=None) -> Graph:
sig = inspect.signature(root.forward)
input_names = sig.parameters.keys() - concrete_args.keys()

self.record(root, input_names, method_names=method_names)

for method_name, cache_name in self.recorded_methods.items():
_wrap_method_for_model_tracing(root, method_name, cache_name)

graph = super().trace(root, concrete_args=concrete_args)

_reset_tensor_methods(self.original_methods)

return graph

def _insert_module_as_submodule(self, mod):
"""
Expand Down Expand Up @@ -202,6 +309,11 @@ def path_of_module(self, mod: torch.nn.Module) -> str:
self.prev_module = path
return path

def create_arg(self, a: Any) -> Argument:
if isinstance(a, range):
return super().create_arg(list(a))
return super().create_arg(a)


def symbolic_trace(
model: PreTrainedModel,
Expand Down Expand Up @@ -249,6 +361,7 @@ def symbolic_trace(
concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)

traced_graph = tracer.trace(model, concrete_args=concrete_args)
traced = torch.fx.GraphModule(model, traced_graph)

Expand Down
1 change: 1 addition & 0 deletions tests/test_modeling_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
fx_ready_model_classes = all_model_classes

test_sequence_classification_problem_types = True

Expand Down
Loading

0 comments on commit f4a0d6f

Please sign in to comment.