Type hints added to Speech to Text (#16506)
* Type hints added

* return hints added

* Update src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Dahlbomii and Rocketknight1 authored Apr 19, 2022
1 parent 1efca4e commit 3dd57b1
Showing 1 changed file with 36 additions and 34 deletions.
@@ -16,8 +16,9 @@
 
 
 import random
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union
 
+import numpy as np
 import tensorflow as tf
 
 from ...activations_tf import get_tf_activation, glu
@@ -29,6 +30,7 @@
 )
 from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
+    TFModelInputType,
     TFPreTrainedModel,
     TFSharedEmbeddings,
     keras_serializable,
@@ -1245,23 +1247,23 @@ def get_decoder(self):
     )
     def call(
         self,
-        input_features=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_features: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
         **kwargs
-    ):
+    ) -> Union[Tuple, TFSeq2SeqModelOutput]:
         outputs = self.model(
             input_features=input_features,
             attention_mask=attention_mask,
@@ -1333,24 +1335,24 @@ def set_output_embeddings(self, new_embeddings):
     @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def call(
         self,
-        input_features=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_features: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_input_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        decoder_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        cross_attn_head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_outputs: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        decoder_inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs
-    ):
+    ) -> Union[Tuple, TFSeq2SeqLMOutput]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
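In practice, these hints document that the array-like arguments of the TF Speech2Text `call()` methods may be passed either as NumPy arrays or as TF tensors. The sketch below is not part of the commit; it assumes a default `Speech2TextConfig` with randomly initialized weights, and the shapes are purely illustrative of the annotated types.

# Minimal sketch (assumptions: default Speech2TextConfig, random weights,
# illustrative shapes) showing that the Union[np.ndarray, tf.Tensor]
# arguments accept both kinds of inputs.
import numpy as np
import tensorflow as tf
from transformers import Speech2TextConfig, TFSpeech2TextModel

config = Speech2TextConfig()        # defaults: 80 mel features, 1 input channel
model = TFSpeech2TextModel(config)  # randomly initialized, no pretrained checkpoint

batch_size, num_frames = 2, 100
feature_dim = config.input_feat_per_channel * config.input_channels

# NumPy inputs satisfy the Optional[Union[np.ndarray, tf.Tensor]] annotations...
features_np = np.random.randn(batch_size, num_frames, feature_dim).astype(np.float32)
decoder_ids_np = np.full((batch_size, 1), config.decoder_start_token_id, dtype=np.int32)
out_np = model(input_features=features_np, decoder_input_ids=decoder_ids_np)

# ...and so do tf.Tensor inputs; with return_dict=True the annotated return
# type Union[Tuple, TFSeq2SeqModelOutput] resolves to a TFSeq2SeqModelOutput.
out_tf = model(
    input_features=tf.convert_to_tensor(features_np),
    decoder_input_ids=tf.convert_to_tensor(decoder_ids_np),
    return_dict=True,
)
print(out_tf.last_hidden_state.shape)  # (2, 1, config.d_model)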
