Unimo output loss (PaddlePaddle#3450)
Yam0214 committed Nov 3, 2022
1 parent 61acf19 commit 6471ce3
Showing 2 changed files with 169 additions and 65 deletions.
125 changes: 94 additions & 31 deletions paddlenlp/transformers/unimo/modeling.py
@@ -13,12 +13,14 @@
# limitations under the License.
"""Modeling classes for UNIMO model."""

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import TransformerEncoder

from .. import PretrainedModel, register_base_model
from ..model_outputs import CausalLMOutputWithCrossAttentions

__all__ = [
"UNIMOPretrainedModel", 'UNIMOModel', 'UNIMOLMHeadModel',
@@ -411,7 +413,10 @@ def forward(self,
position_ids=None,
attention_mask=None,
use_cache=False,
cache=None):
cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False):
r"""
The UNIMOModel forward method, overrides the special :meth:`__call__` method.
@@ -454,14 +459,24 @@ def forward(self,
method. See :meth:`paddle.nn.TransformerEncoder.gen_cache`
method for more details. It is only used for inference and
should be None for training. Defaults to `None`.
output_attentions (bool, optional):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. Defaults to `False`.
output_hidden_states (bool, optional):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.
Returns:
Tensor or tuple: If `use_cache` is False, it is a tensor representing the output of :class:`UNIMOModel`, with
shape [batch_size, sequence_length, hidden_size]. The data type is float64.
Otherwise, it is a tuple, besides the output of :class:`UNIMOModel`, the tuple also includes the new
cache which is same as input `cache` but `incremental_cache` in it has an incremental length.
See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
:meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if
`return_dict=True`. Otherwise it returns a tuple of tensors corresponding
to ordered and not None (depending on the input arguments) fields of
:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`.
In particular, when `return_dict=output_hidden_states=output_attentions=False` and `cache=None`,
it returns a tensor `sequence_output` of shape [batch_size, sequence_length, hidden_size],
which is the output of the last layer of the model.
Example:
.. code-block::
@@ -486,15 +501,18 @@ def forward(self,
embedding_output = self.encoder_norm(embedding_output)
embedding_output = self.dropout(embedding_output)

if use_cache:
if cache is None:
cache = self.encoder.gen_cache(embedding_output)
sequence_output, cache = self.encoder(embedding_output,
attention_mask, cache)
return sequence_output, cache
else:
sequence_output = self.encoder(embedding_output, attention_mask)
return sequence_output
if use_cache and cache is None:
cache = self.encoder.gen_cache(embedding_output)

outputs = self.encoder(
embedding_output,
attention_mask,
cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs
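A minimal usage sketch for the new `output_hidden_states` / `output_attentions` / `return_dict` arguments, assuming the `unimo-text-1.0` checkpoint and the `gen_encode` helper used in the existing docstring examples:

.. code-block::

    from paddlenlp.transformers import UNIMOModel, UNIMOTokenizer

    model = UNIMOModel.from_pretrained('unimo-text-1.0')
    tokenizer = UNIMOTokenizer.from_pretrained('unimo-text-1.0')

    inputs = tokenizer.gen_encode(
        "Welcome to use PaddlePaddle and PaddleNLP!",
        return_tensors=True,
        is_split_into_words=False)
    outputs = model(**inputs,
                    output_hidden_states=True,
                    output_attentions=True,
                    return_dict=True)

    # With return_dict=True the encoder output is a
    # BaseModelOutputWithPastAndCrossAttentions instance.
    print(outputs.last_hidden_state.shape)  # [batch_size, sequence_length, hidden_size]
    print(len(outputs.hidden_states), len(outputs.attentions))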


class UNIMOLMHead(nn.Layer):
@@ -555,7 +573,11 @@ def forward(self,
attention_mask=None,
masked_positions=None,
use_cache=False,
cache=None):
cache=None,
labels=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False):
r"""
The UNIMOLMHeadModel forward method, overrides the special
:meth:`__call__` method.
@@ -573,14 +595,26 @@ def forward(self,
See :class:`UNIMOModel`.
cache (list, optional):
See :class:`UNIMOModel`.
labels (Tensor, optional):
Labels for computing the left-to-right language modeling loss. Indices should be in
`[-100, 0, ..., vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for tokens with labels in `[0, ..., vocab_size]`.
Defaults to `None`.
output_attentions (bool, optional):
See :class:`UNIMOModel`.
output_hidden_states (bool, optional):
See :class:`UNIMOModel`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.
Returns:
Tensor or tuple: If `use_cache` is False, it is a tensor representing the output of :class:`UNIMOModel`, with
shape [batch_size, sequence_length, hidden_size]. The data type is float64.
Otherwise, it is a tuple, besides the output of :class:`UNIMOLMHeadModel`, the tuple also includes the new
cache which is same as input `cache` but `incremental_cache` in it has an incremental length.
See :meth:`paddle.nn.MultiHeadAttention.gen_cache` method and
:meth:`paddle.nn.MultiHeadAttention.forward` method for more details.
An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` if
`return_dict=True`. Otherwise it returns a tuple of tensors corresponding
to ordered and not None (depending on the input arguments) fields of
:class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions`.
In particular, when `return_dict=output_hidden_states=output_attentions=False` and `cache=labels=None`,
it returns a tensor `logits` of shape [batch_size, sequence_length, vocab_size],
which contains the prediction scores of the language modeling head.
Example:
.. code-block::
@@ -598,17 +632,46 @@ def forward(self,
logits = model(**inputs)
"""

outputs = self.unimo(input_ids, token_type_ids, position_ids,
attention_mask, use_cache, cache)
outputs = self.unimo(
input_ids,
token_type_ids,
position_ids,
attention_mask,
use_cache,
cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0] if use_cache else outputs
# The encoder returns a bare tensor when return_dict=False, use_cache=False and no
# extra outputs were requested; otherwise the sequence output is the first element.
sequence_output = outputs if isinstance(outputs, type(input_ids)) else outputs[0]

logits = self.lm_head(sequence_output, masked_positions)
if use_cache:
cache = outputs[1]
return logits, cache
else:
return logits

lm_loss = None
if labels is not None:
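# Flatten logits to [batch_size * sequence_length, vocab_size] and labels to
# [batch_size * sequence_length]; positions labelled -100 are ignored, since
# -100 is the default ignore_index of paddle.nn.CrossEntropyLoss.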
loss_fct = nn.CrossEntropyLoss()
lm_loss = loss_fct(
logits.reshape((-1, self.unimo.config["vocab_size"])),
labels.reshape((-1, )))

if not return_dict:
if isinstance(outputs, type(input_ids)):
return (lm_loss, logits) if lm_loss is not None else logits
else:
outputs = (logits, ) + outputs[1:]
return ((lm_loss, ) +
outputs) if lm_loss is not None else outputs

return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
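A minimal, hypothetical sketch of how the new `labels` argument yields the language modeling loss, again assuming the `unimo-text-1.0` checkpoint; the next-token label construction (shift left by one, mask the last position with -100) is illustrative, not the UNIMO training recipe:

.. code-block::

    import paddle
    from paddlenlp.transformers import UNIMOLMHeadModel, UNIMOTokenizer

    model = UNIMOLMHeadModel.from_pretrained('unimo-text-1.0')
    tokenizer = UNIMOTokenizer.from_pretrained('unimo-text-1.0')

    inputs = tokenizer.gen_encode(
        "Welcome to use PaddlePaddle and PaddleNLP!",
        return_tensors=True,
        is_split_into_words=False)

    # Hypothetical next-token labels: shift input_ids left by one position and
    # mask the final position with -100 so the loss ignores it.
    input_ids = inputs['input_ids']
    labels = paddle.concat([
        input_ids[:, 1:],
        paddle.full([input_ids.shape[0], 1], -100, dtype=input_ids.dtype),
    ], axis=-1)

    outputs = model(**inputs, labels=labels, return_dict=True)
    print(outputs.loss)          # scalar left-to-right LM loss
    print(outputs.logits.shape)  # [batch_size, sequence_length, vocab_size]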

def prepare_faster_entry(self, kwargs):
from paddlenlp.ops import FasterUNIMOText
