Replace ELU

Katsuya Hyodo edited this page Dec 18, 2020
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Add, ReLU, MaxPool2D, Reshape, Concatenate, ZeroPadding2D, Layer
from tensorflow.keras.initializers import Constant
from tensorflow.keras.backend import resize_images
from tensorflow.keras.activations import tanh
from tensorflow.math import sigmoid
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import numpy as np
import sys
import tensorflow_datasets as tfds

# Build a minimal Keras model whose only op is ELU, then export it as a
# float32 TFLite model.
#
# Reference value used in the notes below: for x = -0.1 and alpha = 1.0,
# ELU(x) is approximately -0.09516257.
# x = -0.1
alpha = 1.0  # only referenced by the commented-out pattern1 / pattern2 below

inputs = Input(shape=(10, 10, 3), batch_size=1, name='input')

# ELU rewritten with primitive ops (evaluated for the scalar x above):
# tf.math.maximum(0.0, x) + alpha * (tf.math.exp(tf.math.minimum(0.0, x)) - 1) -> <tf.Tensor: shape=(), dtype=float32, numpy=-0.09516257>
# tf.math.maximum(0.0, x) + alpha * (tf.math.pow(2.71828182845904, tf.math.minimum(0.0, x)) - 1) -> (result not recorded)

# Three interchangeable ways to build the op; enable exactly one.
# op1 = tf.math.maximum(0.0, inputs) + alpha * (tf.math.exp(tf.math.minimum(0.0, inputs)) - 1) # pattern1
# op1 = tf.math.maximum(0.0, inputs) + alpha * (tf.math.pow(2.71828182845904, tf.math.minimum(0.0, inputs)) - 1) # pattern2
op1 = tf.nn.elu(inputs) # pattern3

model = Model(inputs=[inputs], outputs=[op1])
model.summary()
# NOTE(review): the original line was garbled ("model.summary(), 'saved_model_10x10')").
# It is restored here as a SavedModel export to that directory -- confirm this
# matches the intended call.
tf.saved_model.save(model, 'saved_model_10x10')

# Plain float32 conversion (no quantization options set on the converter).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('saved_model_10x10_float32.tflite', 'wb') as w:
    # convert() returns the serialized model; write it out as-is.
    w.write(tflite_model)
print('tflite convert complete! - saved_model_10x10_float32.tflite')

def representative_dataset_gen():
    """Yield calibration samples for the TFLite converter.

    Reads the first 10 records of the module-level ``raw_test_data``
    dataset. Each record's ``'image'`` entry is resized to 10x10, given a
    leading batch axis, and divided by 255. Every sample is yielded as a
    single-element list.
    """
    for sample in raw_test_data.take(10):
        pixels = sample['image'].numpy()
        resized = tf.image.resize(pixels, (10, 10))
        # Prepend a batch axis of size 1.
        batched = resized[np.newaxis, :, :, :]
        yield [batched / 255]
# Dataset read by representative_dataset_gen() during calibration.
# download=False: the data is read from data_dir and not downloaded.
raw_test_data, info = tfds.load(name="coco/2017", with_info=True, split="test", data_dir="~/TFDS", download=False)

# Full integer quantization: int8 input/output types, calibrated with the
# representative dataset generator.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
with open('saved_model_10x10_full_integer_quant.tflite', 'wb') as w:
    # The original `with` block had no body, so the model was never saved.
    w.write(tflite_quant_model)
print('Full Integer Quantization complete! - saved_model_10x10_full_integer_quant.tflite')
