Experimental symbolic tracing feature with torch.fx for BERT, ELECTRA and T5 (huggingface#11475)

Symbolic tracing feature for BERT, ELECTRA and T5

Co-authored-by: Michael Benayoun <michael@huggingface.co>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
4 people committed May 14, 2021
1 parent 94a2348 commit 86d5fb0
Showing 7 changed files with 371 additions and 4 deletions.
19 changes: 19 additions & 0 deletions src/transformers/file_utils.py
@@ -265,6 +265,15 @@ def is_torch_cuda_available():
    return False


_torch_fx_available = False
if _torch_available:
    _torch_fx_available = version.parse(_torch_version) >= version.parse("1.8")


def is_torch_fx_available():
    return _torch_fx_available


def is_tf_available():
    return _tf_available

@@ -1597,11 +1606,21 @@ def wrapper(*args, **kwargs):
    return wrapper


def is_torch_fx_proxy(x):
    if is_torch_fx_available():
        import torch.fx

        return isinstance(x, torch.fx.Proxy)
    return False


def is_tensor(x):
    """
    Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
    :obj:`np.ndarray`.
    """
    if is_torch_fx_proxy(x):
        return True
    if is_torch_available():
        import torch
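For illustration, a minimal sketch (not from this commit; the Probe module is hypothetical) showing that values flowing through a model during torch.fx tracing are proxies, which is why is_tensor must now accept them:

import torch
import torch.fx

from transformers.file_utils import is_torch_fx_proxy


class Probe(torch.nn.Module):
    def forward(self, x):
        # At trace time, x is a torch.fx.Proxy rather than a real tensor.
        print("is proxy during tracing:", is_torch_fx_proxy(x))  # True
        return x


torch.fx.symbolic_trace(Probe())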
253 changes: 253 additions & 0 deletions src/transformers/modeling_fx_utils.py
@@ -0,0 +1,253 @@
import dis
import inspect
from typing import List, Optional, Union

import torch
from torch.fx import GraphModule, Node, Proxy, Tracer

from . import PreTrainedModel


class HFProxy(Proxy):
    """
    Proxy that is able to provide the proper ranks, shapes and boolean values during symbolic tracing by implementing
    the dim, size and __bool__ methods. It can be easily extended by either adding new methods or extending the
    existing ones.
    """

    def __init__(self, node: Node, tracer: Optional[Tracer] = None):
        super().__init__(node, tracer=tracer)
        if hasattr(self, "tracer") and self.tracer is not None:
            self.device = self.tracer.root.device
            self.dtype = next(self.tracer.root.parameters()).dtype

    def dim(self):
        return len(self.tracer.encoder_shape)

    def _shape(self, calling_frame):
        module = calling_frame.f_locals.get("self", None)
        is_decoder = hasattr(module, "is_decoder") and module.is_decoder
        return list(self.tracer.decoder_shape) if is_decoder else list(self.tracer.encoder_shape)

    def size(self, dim=None):
        frame = inspect.currentframe()
        calling_frame = frame.f_back

        # self.size can be called through the shape property, in which case we need to get the outer
        # frame, containing the meaningful information.
        if calling_frame.f_code.co_name == "shape":
            calling_frame = calling_frame.f_back

        instructions = list(reversed(list(dis.get_instructions(calling_frame.f_code))[: calling_frame.f_lasti]))
        code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()

        shape = self._shape(calling_frame)

        if calling_frame.f_code.co_name == "transpose_for_scores":
            # Provides the proper "x.size()" for:
            # new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
            shape = shape + [-1]
        elif "context_layer" in calling_frame.f_locals:
            # Provides the proper "context_layer.size()" for:
            # new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            shape = shape + [-1, -1]
        elif calling_frame.f_locals.get("do_cross_attention", False):
            # Provides the proper shape for:
            # query_length = present_key_value_state[0].shape[2]
            # (modeling_t5.py)
            shape = list(self.tracer.encoder_shape)
            shape = shape[:1] + [-1] + shape[1:2]
        elif "key_length" in code_context or "encoder_seq_length" in code_context:
            shape = list(self.tracer.encoder_shape)
        elif "lm_logits.size(-1)" in code_context:
            shape = [self.tracer.root.config.vocab_size]
        elif "start_positions" in code_context or "end_positions" in code_context:
            # For question answering tasks.
            shape = [1]
        elif "num_choices" in code_context:
            if self.tracer.num_choices <= 0:
                raise ValueError("num_choices must be given to the HFTracer for MultipleChoice tasks.")
            shape = shape[:1] + [self.tracer.num_choices] + shape[1:]
        else:
            # Default case:
            # - If self.size is called for an unpacking, retrieves the corresponding unpacking
            #   instruction, and returns the shape padded as much as necessary to match the expected
            #   number of items.
            # - If self.size is called outside of an unpacking context, simply returns the shape.
            is_unpack = False

            for inst in instructions:
                if inst.opname == "UNPACK_SEQUENCE":
                    is_unpack = True
                    break

            if is_unpack and inst.argval >= 3:
                shape += [self.tracer.root.config.hidden_size]
                dummy_values = [1] * (inst.argval - 3)
                shape += dummy_values

        if dim is not None:
            return shape[dim]

        return tuple(shape)
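To make the bytecode heuristic in the default case above concrete, here is a standalone sketch (not part of this file) of what dis reveals about an unpacking call site:

import dis


def consumer(tensor):
    batch_size, seq_length, hidden_size = tensor.size()
    return batch_size


# The instruction list contains UNPACK_SEQUENCE with argval == 3, which is
# what size() looks for when deciding how much padding the shape needs.
print([ins.opname for ins in dis.get_instructions(consumer)])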

    @property
    def shape(self):
        return self.size()

    def __bool__(self) -> bool:
        frame = inspect.currentframe()
        calling_frame = frame.f_back
        code_context = inspect.getframeinfo(calling_frame).code_context[0].strip()
        if calling_frame.f_code.co_name == "apply_chunking_to_forward":
            # Returning True to every assertion in "apply_chunking_to_forward"
            return True
        elif "assert" in code_context:
            # Returning True to any assertion.
            return True
        elif calling_frame.f_code.co_name == "get_extended_attention_mask":
            # Corresponding to:
            # if causal_mask.shape[1] < attention_mask.shape[1]:
            return calling_frame.f_back.f_locals["past_key_values"][0] is not None
        raise NotImplementedError("__bool__ was called for HFProxy, but this case is not covered yet.")

    def __setitem__(self, key, value):
        pass

    def __contains__(self, key):
        return False
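Why these overrides matter: a minimal sketch (separate from this file) of stock torch.fx failing on the data-dependent control flow that HFProxy papers over.

import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        if x.size(0) > 1:  # bool() on a Proxy raises with the stock tracer
            return x * 2
        return x


try:
    torch.fx.symbolic_trace(Toy())
except torch.fx.proxy.TraceError as err:
    print(err)  # symbolically traced variables cannot be used as inputs to control flow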


class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library (currently BERT, ELECTRA and T5). To do that, it
    uses the HFProxy instead of the regular PyTorch torch.fx.Proxy.
    """

    def __init__(self, batch_size=1, sequence_length=[128, 128], num_choices=-1):
        super().__init__()
        encoder_sequence_length = sequence_length[0] if isinstance(sequence_length, (list, tuple)) else sequence_length
        decoder_sequence_length = sequence_length[1] if isinstance(sequence_length, (list, tuple)) else -1
        self.encoder_shape = [batch_size, encoder_sequence_length]
        self.decoder_shape = (
            [batch_size, decoder_sequence_length] if decoder_sequence_length > 0 else list(self.encoder_shape)
        )
        self.num_choices = num_choices
        if self.num_choices > 0:
            self.encoder_shape[0] *= self.num_choices

        self.prev_module = None

    def proxy(self, node: Node):
        return HFProxy(node, self)

    def _insert_module_as_submodule(self, mod):
        """
        Helper method which tries to insert a module that was not declared as a submodule.
        """
        # First, retrieve the parent module.
        if self.prev_module is None:
            return None
        parent_path = self.prev_module.rsplit(".", 1)[0]
        parent_mod = None
        for path, module in self.root.named_modules():
            if path == parent_path:
                parent_mod = module
                break
        if parent_mod is None:
            return None

        # If retrieving the parent module was possible, set the module not declared as a submodule
        # as a parent module attribute.
        path = None
        for var_name, var_val in inspect.currentframe().f_back.f_locals.items():
            if mod is var_val:
                setattr(parent_mod, var_name, mod)
                path = f"{parent_path}.{var_name}"
                break

        return path

    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if
        ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this
        function will return the string "foo.bar".

        Args:
            mod (:obj:`torch.nn.Module`): The module to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if hasattr(self, "submodule_paths") and self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            self.prev_module = path
            return path

        # O(N^2) fallback in the case that we didn't store the submodule paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    self.prev_module = n
                    return n
            path = self._insert_module_as_submodule(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            self.prev_module = path
            return path
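A small usage sketch (not part of this file, assuming a transformers build containing this commit) of the shape bookkeeping the constructor sets up:

from transformers.modeling_fx_utils import HFTracer

tracer = HFTracer(batch_size=2, sequence_length=[16, 8])
print(tracer.encoder_shape)  # [2, 16]
print(tracer.decoder_shape)  # [2, 8]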


def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    batch_size: int = 1,
    sequence_length: Union[int, List[int]] = [128, 128],
    num_choices: int = -1,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model (:obj:`PreTrainedModel`):
            The model to trace.
        input_names (:obj:`List[str]`, `optional`):
            The names of the inputs of the traced model. If unset, ``model.dummy_inputs.keys()`` are used instead.
        batch_size (:obj:`int`, `optional`, defaults to 1):
            The batch size of the traced model inputs.
        sequence_length (:obj:`int` or :obj:`List[int]`, `optional`, defaults to :obj:`[128, 128]`):
            The sequence length of the traced model inputs. For sequence-to-sequence models with different sequence
            lengths between the encoder and the decoder inputs, this must be :obj:`[encoder_sequence_length,
            decoder_sequence_length]`.
        num_choices (:obj:`int`, `optional`, defaults to -1):
            The number of possible choices for a multiple choice task.

    Returns:
        :obj:`torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example::

        from transformers.modeling_fx_utils import symbolic_trace

        traced_model = symbolic_trace(
            model,
            input_names=["input_ids", "attention_mask", "token_type_ids"],
            batch_size=1,
            sequence_length=128,
        )
    """
    if input_names is None:
        input_names = model.dummy_inputs.keys()

    sig = inspect.signature(model.forward)
    # TODO: how to handle the case of the "return_dict" parameter.
    concrete_args = {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

    tracer = HFTracer(batch_size=batch_size, sequence_length=sequence_length, num_choices=num_choices)
    traced_graph = tracer.trace(model, concrete_args=concrete_args)
    traced = torch.fx.GraphModule(model, traced_graph)

    return traced
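Putting the pieces together, a usage sketch following the docstring example (the checkpoint name is an assumption): the returned GraphModule runs like the original model for the traced input shape, and its generated forward can be inspected.

import torch
from transformers import BertModel
from transformers.modeling_fx_utils import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    batch_size=1,
    sequence_length=128,
)

print(traced.code)  # the Python code recorded for the traced forward pass

input_ids = torch.zeros(1, 128, dtype=torch.long)
outputs = traced(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    token_type_ids=torch.zeros_like(input_ids),
)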
12 changes: 9 additions & 3 deletions src/transformers/models/t5/modeling_t5.py
@@ -32,6 +32,7 @@
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    replace_return_docstrings,
)
from ...modeling_outputs import (
@@ -776,9 +777,14 @@ def _shift_right(self, input_ids):
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id

assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
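As a sanity check (not part of the diff), the concat-based shift used on the proxy path produces the same result as the item-assignment path on a real tensor:

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
decoder_start_token_id = 0

# Item-assignment version (eager path).
eager = input_ids.new_zeros(input_ids.shape)
eager[..., 1:] = input_ids[..., :-1].clone()
eager[..., 0] = decoder_start_token_id

# Concat version (proxy path).
prefix = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted = torch.cat([prefix, input_ids[..., :-1]], dim=-1)

assert torch.equal(eager, shifted)  # both are tensor([[0, 5, 6, 7]])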
1 change: 1 addition & 0 deletions tests/test_modeling_bert.py
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    fx_ready_model_classes = all_model_classes
    test_sequence_classification_problem_types = True

    # special case for ForPreTraining model
