Skip to content

Commit

Permalink
Merge pull request TensorSpeech#84 from jaeyoo/patch-2
Browse files Browse the repository at this point in the history
🚀 📝 Add TFLite-convertible TFFastSpeech
  • Loading branch information
dathudeptrai authored Jul 3, 2020
2 parents b8c22a6 + a8fc65c commit 300b401
Showing 1 changed file with 124 additions and 50 deletions.
174 changes: 124 additions & 50 deletions tensorflow_tts/models/fastspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,14 @@ def call(self, inputs, training=False):
class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
"""FastSpeech lengthregulator module."""

def __init__(self, config, **kwargs):
def __init__(self, config, enable_tflite_convertible = False, **kwargs):
"""Init variables."""
super().__init__(**kwargs)
self.config = config
self.enable_tflite_convertible = enable_tflite_convertible

def call(self, inputs, training=False):
"""Call logic.
Args:
1. encoder_hidden_states, Tensor (float32) shape [batch_size, length, hidden_size]
2. durations_gt, Tensor (float32/int32) shape [batch_size, length]
Expand All @@ -601,75 +601,93 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
hidden_size = input_shape[-1]

# initialize output hidden states and encoder masking.
outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

def condition(
i,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
):
return tf.less(i, batch_size)

def body(
i,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
):
repeats = durations_gt[i]
if self.enable_tflite_convertible:
# There is only 1 batch in inference, so we don't have to use
# `tf.While` op with 3-D output tensor.
repeats = durations_gt[0]
real_length = tf.reduce_sum(repeats)
pad_size = max_durations - real_length
# masks : [max_durations]
masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
repeat_encoder_hidden_states = tf.repeat(
encoder_hidden_states[i], repeats=repeats, axis=0
encoder_hidden_states[0], repeats=repeats, axis=0
)
repeat_encoder_hidden_states = tf.expand_dims(
tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
) # [1, max_durations, hidden_size]
outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
encoder_masks = tf.concat([encoder_masks, masks], axis=0)
return [
i + 1,
) # [1, max_durations, hidden_size]

outputs = repeat_encoder_hidden_states
encoder_masks = masks
else:
outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

def condition(
i,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
]
):
return tf.less(i, batch_size)

# initialize iteration i.
i = tf.constant(0, dtype=tf.int32)
_, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
condition,
body,
[
def body(
i,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
],
shape_invariants=[
i.get_shape(),
batch_size.get_shape(),
tf.TensorShape([None, None, self.config.hidden_size]),
tf.TensorShape([None, None]),
encoder_hidden_states.get_shape(),
durations_gt.get_shape(),
max_durations.get_shape(),
],
)
):
repeats = durations_gt[i]
real_length = tf.reduce_sum(repeats)
pad_size = max_durations - real_length
masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
repeat_encoder_hidden_states = tf.repeat(
encoder_hidden_states[i], repeats=repeats, axis=0
)
repeat_encoder_hidden_states = tf.expand_dims(
tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
) # [1, max_durations, hidden_size]
outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
encoder_masks = tf.concat([encoder_masks, masks], axis=0)
return [
i + 1,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
]

# initialize iteration i.
i = tf.constant(0, dtype=tf.int32)
_, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
condition,
body,
[
i,
batch_size,
outputs,
encoder_masks,
encoder_hidden_states,
durations_gt,
max_durations,
],
shape_invariants=[
i.get_shape(),
batch_size.get_shape(),
tf.TensorShape([None, None, self.config.hidden_size]),
tf.TensorShape([None, None]),
encoder_hidden_states.get_shape(),
durations_gt.get_shape(),
max_durations.get_shape(),
],
)

return outputs, encoder_masks

Expand Down Expand Up @@ -799,3 +817,59 @@ def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):

outputs = (mel_before, mel_after, duration_outputs)
return outputs

@tf.function(
experimental_relax_shapes=True,
input_signature=[
tf.TensorSpec(shape=[1, None], dtype=tf.int32),
tf.TensorSpec(shape=[1, None], dtype=tf.bool),
tf.TensorSpec(shape=[1,], dtype=tf.int32),
tf.TensorSpec(shape=[1,], dtype=tf.float32),
],
)
def inference_tflite(self, input_ids, attention_mask, speaker_ids, speed_ratios):
"""Call logic."""
embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
encoder_output = self.encoder(
[embedding_output, attention_mask], training=False
)
last_encoder_hidden_states = encoder_output[0]

# duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
# rather than just use last_hidden_states of encoder for duration_predictor.
duration_outputs = self.duration_predictor(
[last_encoder_hidden_states, attention_mask]
) # [batch_size, length]
duration_outputs = tf.math.exp(duration_outputs) - 1.0

if speed_ratios is None:
speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)

duration_outputs = tf.cast(
tf.math.round(duration_outputs * speed_ratios), tf.int32
)

length_regulator_outputs, encoder_masks = self.length_regulator(
[last_encoder_hidden_states, duration_outputs], training=False
)

# create decoder positional embedding
decoder_pos = tf.range(
1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
)
masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks

decoder_output = self.decoder(
[length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
training=False,
)
last_decoder_hidden_states = decoder_output[0]

# here u can use sum or concat more than 1 hidden states layers from decoder.
mel_before = self.mel_dense(last_decoder_hidden_states)
mel_after = (
self.postnet([mel_before, encoder_masks], training=False) + mel_before
)

outputs = (mel_before, mel_after, duration_outputs)
return outputs

0 comments on commit 300b401

Please sign in to comment.