[quant][fx] Add support for GRU in fx graph mode quantization (pytorch#91976)

Summary:
might be needed by a meta-internal use case

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_rnn

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#91976
Approved by: https://github.com/jcaip
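For reference, a minimal sketch of the flow this commit enables — dynamic quantization of an nn.GRU through FX graph mode — assuming the standard torch.ao.quantization APIs; the model, shapes, and qconfig choice below are illustrative and not taken from this PR:

```python
# Hedged sketch: FX graph mode dynamic quantization of a model containing nn.GRU.
# All names below are illustrative assumptions, not code from this commit.
import torch
from torch.ao.quantization import QConfigMapping, default_dynamic_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class GRUModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=2, hidden_size=4, num_layers=1)

    def forward(self, x):
        out, h = self.gru(x)
        return out

model = GRUModel().eval()
example_inputs = (torch.randn(5, 3, 2),)  # (seq_len, batch, input_size)
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.GRU, default_dynamic_qconfig)
prepared = prepare_fx(model, qconfig_mapping, example_inputs)
quantized = convert_fx(prepared)  # nn.GRU should now lower to the dynamic quantized GRU
```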
jerryzh168 authored and pytorchmergebot committed Jan 13, 2023
1 parent 0bd3fa3 commit ec3941a
Showing 8 changed files with 154 additions and 7 deletions.
4 changes: 2 additions & 2 deletions test/quantization/fx/test_quantize_fx.py
@@ -7908,8 +7908,8 @@ def test_rnn(self):
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            return
        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
        module_type_strs = ['LSTM']
        module_types = [torch.nn.LSTM]
        module_type_strs = ['LSTM', 'GRU']
        module_types = [torch.nn.LSTM, torch.nn.GRU]
        niter = 10
        sample_input = torch.tensor([[100, -155],
                                     [-155, 100],
17 changes: 17 additions & 0 deletions torch/ao/nn/quantized/dynamic/modules/rnn.py
@@ -730,6 +730,23 @@ def forward(self, input, hx=None):
    def from_float(cls, mod):
        return super(GRU, cls).from_float(mod)

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), ("We are assuming weight_ih_l0 "
            "exists in GRU, may need to relax the assumption to support the use case")
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod

class RNNCellBase(torch.nn.Module):
    # _FLOAT_MODULE = nn.CellRNNBase
1 change: 1 addition & 0 deletions torch/ao/nn/quantized/reference/__init__.py
@@ -12,6 +12,7 @@
    'LSTMCell',
    'GRUCell',
    'LSTM',
    'GRU',
    'Embedding',
    'EmbeddingBag',
]
3 changes: 2 additions & 1 deletion torch/ao/nn/quantized/reference/modules/__init__.py
@@ -1,6 +1,6 @@
from .linear import Linear
from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from .rnn import RNNCell, LSTMCell, GRUCell, LSTM
from .rnn import RNNCell, LSTMCell, GRUCell, LSTM, GRU
from .sparse import Embedding, EmbeddingBag

__all__ = [
@@ -15,6 +15,7 @@
    'LSTMCell',
    'GRUCell',
    'LSTM',
    'GRU',
    'Embedding',
    'EmbeddingBag',
]
128 changes: 127 additions & 1 deletion torch/ao/nn/quantized/reference/modules/rnn.py
@@ -7,7 +7,7 @@
from torch import _VF
from torch.nn.utils.rnn import PackedSequence

__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'get_quantized_weight']
__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight']

def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)
@@ -488,3 +488,129 @@ def from_float(cls, mod, weight_qparams_dict):
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod

class GRU(RNNBase):
    """ Reference Quantized GRU Module
    We store weight_qparams for all the weights in _flat_weights, so we need to pass in
    a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0,
    to the weight_qparams for that weight
    """
    def __init__(self, *args, **kwargs):
        if 'proj_size' in kwargs:
            raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU")
        super().__init__('GRU', *args, **kwargs)

    def get_quantized_weight_bias_dict(self):
        """ dictionary from flat_weight_name to quantized weight or (unquantized) bias
        e.g.
        {
          "weight_ih_l0": quantized_weight,
          "bias_ih_l0": unquantized_bias,
          ...
        }
        """
        quantized_weight_bias_dict = {}
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                if wn.startswith("weight"):
                    weight_or_bias = get_quantized_weight(self, wn)
                else:
                    weight_or_bias = getattr(self, wn)
            else:
                weight_or_bias = None
            quantized_weight_bias_dict[wn] = weight_or_bias
        return quantized_weight_bias_dict

    def get_flat_weights(self):
        flat_weights = []
        for wn in self._flat_weights_names:
            if hasattr(self, wn):
                weight = getattr(self, wn)
                if wn.startswith("weight"):
                    params = _get_weight_and_quantization_params(self, wn)
                    weight = _quantize_and_dequantize_weight(*params)
            else:
                weight = None
            flat_weights.append(weight)
        return flat_weights

    def forward(self, input, hx=None):  # noqa: F811
        # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
        # only changed self._flat_weights to self.get_flat_weights()
        # TODO: maybe we can try inheriting from that class and define get_flat_weights
        # as a @property? this might interfere with TorchScript, if we remove that
        # requirement in the future we should be able to do this
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor")
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor")
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers,
                             self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias,
                             self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                output = output.squeeze(batch_dim)
                hidden = hidden.squeeze(1)

            return output, self.permute_hidden(hidden, unsorted_indices)

    def _get_name(self):
        return "QuantizedGRU(Reference)"

    @classmethod
    def from_float(cls, mod, weight_qparams_dict):
        ref_mod = cls(
            mod.input_size,
            mod.hidden_size,
            mod.num_layers,
            mod.bias,
            mod.batch_first,
            mod.dropout,
            mod.bidirectional,
            weight_qparams_dict=weight_qparams_dict)
        for wn in mod._flat_weights_names:
            setattr(ref_mod, wn, getattr(mod, wn))
        return ref_mod
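For context, the `weight_qparams_dict` consumed by `from_float` above maps each flat weight name to its quantization parameters. A hypothetical shape, assuming per-tensor affine qparams as described in the class docstring; the field names and values are illustrative and not taken from this PR:

```python
# Hypothetical weight_qparams_dict for a 1-layer, unidirectional GRU; the keys mirror
# nn.GRU._flat_weights_names, and the per-weight fields are assumptions for illustration.
import torch

per_weight_qparams = {
    "qscheme": torch.per_tensor_affine,
    "dtype": torch.qint8,
    "scale": 0.02,
    "zero_point": 0,
}
weight_qparams_dict = {
    "weight_ih_l0": per_weight_qparams,
    "weight_hh_l0": per_weight_qparams,
}
```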
@@ -6,7 +6,7 @@
import torch.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn.qat as nnqat
import torch.nn.quantized._reference as nnqr
import torch.ao.nn.quantized.reference as nnqr
from collections import namedtuple
from typing import Callable, Dict, List, Union
from .backend_config import (
@@ -580,7 +580,8 @@ def _get_rnn_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPattern
            (nn.GRUCell, nnqr.GRUCell),
            (nn.LSTMCell, nnqr.LSTMCell),
            (nn.RNNCell, nnqr.RNNCell),
            (nn.LSTM, nnqr.LSTM)
            (nn.LSTM, nnqr.LSTM),
            (nn.GRU, nnqr.GRU)
    ]:
        rnn_op_configs.append(
            BackendPatternConfig(rnn_op)
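The loop above pairs each float RNN type with its reference quantized module; the rest of the chained config is truncated in this view. A hedged sketch of what the GRU entry likely amounts to, using the public BackendPatternConfig API — the observation type and exact method chain are assumptions, not the hidden lines of this file:

```python
# Illustrative only: how nn.GRU might be registered with its reference module so that
# FX convert can swap it in during lowering.
import torch.nn as nn
import torch.ao.nn.quantized.reference as nnqr
from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType

dtype_configs = []  # in the real helper, this comes from the dtype_configs parameter
gru_config = (
    BackendPatternConfig(nn.GRU)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # assumption
    .set_dtype_configs(dtype_configs)
    .set_root_module(nn.GRU)
    .set_reference_quantized_module(nnqr.GRU)
)
```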
1 change: 1 addition & 0 deletions torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -218,6 +218,7 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigA
    nnqr.LSTMCell: nnqd.LSTMCell,
    nnqr.RNNCell: nnqd.RNNCell,
    nnqr.LSTM: nnqd.LSTM,
    nnqr.GRU: nnqd.GRU,
}

# Mapping from reference module class to the replacement weight only quantized module class for lowering
2 changes: 1 addition & 1 deletion torch/ao/quantization/fx/convert.py
@@ -716,7 +716,7 @@ def convert_weighted_module(
"weight_ih": weight_qparams_ih,
"weight_hh": weight_qparams_hh,
})
elif isinstance(float_module, torch.nn.LSTM):
elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
# format for wq_or_wq_dict (flattened attributes):
# {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
for wn in float_module._flat_weights_names:
