TensorFlow Lite Conversion
TF Lite Conversion Comparison
This page provides guidance on using TFLite to convert and deploy models.
We use a LeNet-like CNN model on the MNIST dataset. The workflow is general; however, the performance of the TF Lite model (compression, accuracy) will differ depending on your models and datasets.
Specifically, I am going to explain the workflow that is buried in the TensorFlow Lite documentation.
# !pip install -U tensorflow=2.0.0
!rm -rf *.tflite
!mkdir -p tmp
!rm -rf tmp/*.tflite
%tensorflow_version 1.15
from google.colab import files
import tensorflow as tf
from tensorflow import keras
from tensorflow import lite
import numpy as np
import matplotlib.pylab as plt
from packaging import version
from os import path
import pandas as pd
import os
from IPython.core.display import HTML
import time
%matplotlib inline
# Reduce TensorFlow's native (C++) log verbosity.
# NOTE(review): TF generally reads this variable when its native runtime first
# logs, which may already have happened during `import tensorflow` above --
# consider setting it before that import.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})
# True under TF 1.x; used below to choose between the 1.x and 2.x converter APIs.
ver1_flag = version.parse("2.0") > version.parse(tf.__version__)
tf.__version__
Load data
We also create two generator factories, `create_data` and `create_represent_data`, for later TFLite use.
# load mnist data for testing
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis and scale pixels to [0, 1]. Using -1 lets
# reshape infer the sample count instead of hard-coding 60000 / 10000.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255
# Labels are kept as float32 class ids for SparseCategoricalCrossentropy.
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')
def create_data(data):
    """Build a re-iterable input source for the TFLite interpreter loops.

    Returns a zero-argument function; each call creates a fresh generator that
    yields every element of ``data`` wrapped in a one-element list.
    """
    def data_gen():
        yield from ([sample] for sample in data)

    return data_gen
def create_represent_data(data):
    """Build a representative-dataset generator factory for the TFLite converter.

    Returns a zero-argument function; each call creates a fresh generator that
    yields, for every element of ``data``, a list holding one input, where that
    input is itself the element wrapped in a one-element list.
    """
    def data_gen():
        for sample in data:
            yield [[sample]]

    return data_gen
np.shape(x_train)  # inspect the reshaped training tensor
Build Keras Model
We build a simple CNN model for testing.
keras.backend.clear_session()


def _compile(model):
    """Attach the optimizer, loss and metric used by this notebook; return *model*."""
    model.compile(optimizer=keras.optimizers.Adam(),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])
    return model


if path.isfile("model.h5"):  # try to avoid training again: load the model if present
    # Recompile with the same settings a freshly built model gets, so
    # evaluate() below reports the accuracy metric in either branch.
    m = _compile(keras.models.load_model("model.h5"))
else:
    # Only build the network when there is nothing to load.
    m = _compile(keras.Sequential([
        keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Conv2D(16, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax'),
    ]))
    m.fit(x_train, y_train, batch_size=128, epochs=10)
    m.save("model.h5")
m.evaluate(x_test, y_test)[1]  ## accuracy
# `m` already holds what model.h5 contains (loaded or just saved), so it is
# not reloaded here.
plain_res = m.predict(x_test)
plain_res.shape
np.mean(np.argmax(plain_res, axis=1) == y_test)  # test accuracy
TF Lite conversion options
def get_conv(model_file):  # create tflite converter for keras model
    """Return a TFLiteConverter for the Keras model stored at *model_file*.

    Under TF 1.x (``ver1_flag``) the converter reads the file directly.
    Otherwise the model is loaded with Keras first and converted from memory.
    """
    if not ver1_flag:
        loaded = keras.models.load_model(model_file)
        return lite.TFLiteConverter.from_keras_model(loaded)
    return lite.TFLiteConverter.from_keras_model_file(model_file)
def get_diff(result1, result2):
    """Compare two batches of per-class predictions.

    Parameters
    ----------
    result1, result2 : array-like of shape (n_samples, n_classes)
        Predictions to compare; ``result1`` is the reference.

    Returns
    -------
    mismatch : int
        Number of samples whose arg-max class differs between the two.
    diff : ndarray of shape (n_samples,)
        For each sample, ``result1 - result2`` evaluated at the class that
        ``result1`` predicts for that sample.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    result1 = np.asarray(result1)
    result2 = np.asarray(result2)
    if result1.shape != result2.shape:
        raise ValueError("shape mismatch: %s vs %s" % (result1.shape, result2.shape))
    id1 = np.argmax(result1, axis=1)
    id2 = np.argmax(result2, axis=1)
    mismatch = int(np.count_nonzero(id1 != id2))
    # Pair every row with its own predicted class. Indexing with id1 alone
    # (result1[id1]) would select whole rows by class id instead.
    rows = np.arange(result1.shape[0])
    diff = result1[rows, id1] - result2[rows, id1]
    return mismatch, diff
def get_res(filename, data_gen):  # get interpreter output
    """Run a TFLite model over every input produced by *data_gen*.

    Parameters
    ----------
    filename : str
        Path to the ``.tflite`` model.
    data_gen : callable
        Zero-argument callable returning an iterable of inputs (for example
        the factory returned by ``create_data``). Each item is fed to the
        model's first input tensor.

    Yields
    ------
    The model's first output tensor, once per input.
    """
    intp = lite.Interpreter(filename)
    intp.allocate_tensors()
    # The input/output tensor indices do not change between invocations,
    # so look them up once instead of on every iteration.
    input_index = intp.get_input_details()[0]['index']
    output_index = intp.get_output_details()[0]['index']
    for sample in data_gen():
        intp.set_tensor(input_index, sample)
        intp.invoke()
        yield intp.get_tensor(output_index)
def get_acc(filename):
    """Return the top-1 accuracy of a TFLite model on ``x_test`` / ``y_test``."""
    data_gen = create_data(x_test)
    pred = np.asarray(list(get_res(filename, data_gen)))
    # Flatten each per-sample output to one row. Unlike np.squeeze, this keeps
    # a 2-D (n_samples, n_classes) array even when there is a single sample.
    pred = pred.reshape(len(pred), -1)
    return np.mean(np.argmax(pred, axis=1) == y_test)
def get_res2(filename, data_gen):
    """Run a TFLite model and yield ``(predicted_class, seconds)`` for each input.

    Parameters
    ----------
    filename : str
        Path to the ``.tflite`` model.
    data_gen : callable
        Zero-argument callable returning an iterable of inputs (for example
        the factory returned by ``create_data``).

    Yields
    ------
    tuple
        Arg-max of the model's first output, and the wall-clock time spent in
        ``set_tensor`` + ``invoke`` for that input.
    """
    intp = lite.Interpreter(filename)
    intp.allocate_tensors()
    # Look the tensor indices up once, outside the timed region, so the
    # measurement covers only feeding the input and running the model.
    input_index = intp.get_input_details()[0]['index']
    output_index = intp.get_output_details()[0]['index']
    for sample in data_gen():
        # perf_counter is the highest-resolution clock intended for benchmarking.
        start = time.perf_counter()
        intp.set_tensor(input_index, sample)
        intp.invoke()
        elapsed = time.perf_counter() - start
        yield np.argmax(intp.get_tensor(output_index)), elapsed
def get_acc_and_time(filename):
    """Evaluate a TFLite model on ``x_test`` / ``y_test``.

    Returns
    -------
    tuple
        ``(accuracy, mean_inference_seconds, std_inference_seconds)``.
    """
    data_gen = create_data(x_test)
    pairs = list(get_res2(filename, data_gen))
    # Unzip explicitly rather than np.squeeze-ing the (class, time) pairs:
    # squeeze collapses a single-sample result to 1-D and breaks [:, 0].
    labels = np.array([label for label, _ in pairs])
    times = np.array([seconds for _, seconds in pairs])
    return np.mean(labels == y_test), np.mean(times), np.std(times)  # acc, mean, std of inference
Collect all options for TF Lite conversion
# for converter.target_spec.supported_types
type_choice = {}
if ver1_flag:
    # lite.constants.__all__ also exports non-dtype constants such as the
    # TFLITE / GRAPHVIZ_DOT output formats. Those control the output graph
    # type, not a tensor type, so keep only genuine tf.DType values.
    for name in lite.constants.__all__:
        value = getattr(lite.constants, name)
        if isinstance(value, tf.DType):
            type_choice[name.lower()] = [value]
else:
    from tensorflow.lite.python import lite_constants as constants
    type_choice = {
        "float": [constants.FLOAT],  # tf.float32
        "float16": [constants.FLOAT16],  # tf.float16, same option the TF 1.x branch exposes
        "int8": [constants.INT8],  # tf.int8
        "int32": [constants.INT32],  # tf.int32
        "int64": [constants.INT64],  # tf.int64
        "string": [constants.STRING],  # tf.string
        "uint8": [constants.QUANTIZED_UINT8],  # tf.uint8
    }
# None leaves supported_types unset (converter default).
type_choice['none'] = None
# Option tables swept by the conversion loop.
# for converter.target_spec.supported_ops
ops_choice = dict(
    int8=[lite.OpsSet.TFLITE_BUILTINS_INT8],
    tflite=[lite.OpsSet.TFLITE_BUILTINS],  # default
    tf=[lite.OpsSet.SELECT_TF_OPS, lite.OpsSet.TFLITE_BUILTINS],
)
# for converter.optimizations ("none" = empty list, i.e. no optimization)
opt_choice = dict(
    default=[lite.Optimize.DEFAULT],
    latency=[lite.Optimize.OPTIMIZE_FOR_LATENCY],
    size=[lite.Optimize.OPTIMIZE_FOR_SIZE],
    none=[],
)
# for converter.representative_dataset (first 5000 training images)
data_gen2 = create_represent_data(x_train[:5000])
data_choice = dict(with_data=data_gen2, wo_data=None)
# Under TF 1.x, lite.constants also defines "tflite" and "graphviz_dot";
# those control the output graph type rather than a tensor type.
type_choice
%%capture convert_log
# output has been cleared
res = []
for xdata in data_choice:
for xopt in opt_choice:
for xops in ops_choice:
for xtype in type_choice:
filename = "tmp/%s-opt(%s)-ops(%s)-type(%s).tflite"%(xdata, xopt, xops, xtype)
print("******** %s ********" % filename)
keras.backend.clear_session()
try:
conv = get_conv("model.h5")
conv.optimizations = opt_choice[xopt]
conv.representative_dataset = data_choice[xdata]
conv.target_spec.supported_ops = ops_choice[xops]
conv.target_spec.supported_types = type_choice[xtype]
fb = conv.convert()
msg = ("success")
with open(filename, 'wb') as f:
f.write(fb)
size = path.getsize(filename)
print("finished")
acc = get_acc_and_time(filename)
except Exception as e:
msg = e.__str__()
print("failed - %s"%msg)
size = None
acc = None, None, None
finally:
res.append([xdata, xopt, xops, xtype, size, msg, *acc])
result = pd.DataFrame(res, columns=["data", "optimization", "ops", "type", "size", "status", "accuracy","mean_inference","std_inference"])
result.to_pickle("result.pkl")
files.download("result.pkl")
%%javascript
require.config({
paths: {
DT: '//cdn.datatables.net/1.10.19/js/jquery.dataTables.min',
}
});
Raw results
HTML(data=result.to_html())  # full table, including failed conversions
Successfully converted TF Lite models
# Failed conversions store None for size/accuracy, so dropping rows with any
# missing value leaves only the models that converted.
HTML(data=result.dropna(axis=0).to_html())
TF Lite Interpreter result details
data_gen = create_data(data=x_test)  # test-set input factory shared by the comparisons below
Plain TF Lite conversion
# No representative data, no optimization, select-TF ops, float type.
plain_tflite = np.squeeze(list(
    get_res("tmp/wo_data-opt(none)-ops(tf)-type(float).tflite", data_gen)
))
mismatch, diff = get_diff(plain_res, plain_tflite)
plt.hist(diff)
plt.title("total mismatch=%d" % (mismatch,))
plt.show()
def print_inpt(filename):
    """Print one tab-separated line per tensor of the TFLite model at *filename*.

    Columns: tensor index, dtype name, tensor name, shape, and the
    ``quantization`` pair reported by ``get_tensor_details``.
    """
    intp = lite.Interpreter(filename)
    for detail in intp.get_tensor_details():
        # np.dtype(...).name yields e.g. "float32" / "int8" directly, instead of
        # slicing the repr "<class 'numpy.float32'>", which breaks if the dtype
        # is reported as a dtype instance rather than a scalar type.
        dtype_name = np.dtype(detail['dtype']).name
        print("\t".join([
            "%d" % detail['index'],
            dtype_name,
            detail['name'],
            "%s" % detail['shape'],
            "(%f,%f)" % detail['quantization'],
        ]))


print_inpt("tmp/wo_data-opt(none)-ops(tf)-type(float).tflite")
TF default optimization
# Default optimization, builtin ops, float type, no representative data.
plain_opt = np.squeeze(list(
    get_res("tmp/wo_data-opt(default)-ops(tflite)-type(float).tflite", data_gen)
))
mismatch, diff = get_diff(plain_res, plain_opt)
plt.hist(diff)
plt.title("total mismatch=%d" % (mismatch,))
plt.show()
print_inpt("tmp/wo_data-opt(default)-ops(tflite)-type(float).tflite")
TF with representative data
# Default optimization, builtin ops, float type, with representative data.
data_opt = np.squeeze(list(
    get_res("tmp/with_data-opt(default)-ops(tflite)-type(float).tflite", data_gen)
))
mismatch, diff = get_diff(plain_res, data_opt)
plt.hist(diff)
plt.title("total mismatch=%d" % (mismatch,))
plt.show()
print_inpt("tmp/with_data-opt(default)-ops(tflite)-type(float).tflite")
Other quick comparisons
# Tensor listings for two int8-typed conversions with representative data.
for model_path in ("tmp/with_data-opt(default)-ops(int8)-type(int8).tflite",
                   "tmp/with_data-opt(size)-ops(tflite)-type(int8).tflite"):
    print_inpt(model_path)
Remarks
Take-aways
Now, let's take another look at the workflow provided by TF.
- Optimization types are not fully implemented; see [here](https://github.com/tensorflow/tensorflow/blob/570206441717511720fdae9ac58dac16cc1d348a/tensorflow/lite/python/lite.py#L96).
- With no data and no optimization, the ops and types settings don't matter, except for the cases that crash (e.g. int). The converter creates a float32 TFLite model for runtime. This corresponds to the "No" branch of "Optimize model?" in the TF workflow.
- Exception case 1: using int in types/ops.
- Exception case 2: with data, and ops set to int (types int or none).
- With optimization and the `float16` type, the model is reduced to half its size.
- With optimization (and without `float16`), some weights are quantized to int8.
- With data and optimization, weights are in an int type. However, int8 is not strictly enforced.
- When ops is `int8`, the data type needs to be `int8`.
- int8 and uint8 are quite different.
Remaining mysteries
- What's the difference between `select` and `builtin` ops?
- What do the `string` and `none` type options do?
An unexpected problem
If you check the [source code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/kernel_util.cc#L100), the relevant check lives in the `GetQuantizedConvolutionMultipler` function, so there is some interesting conversion happening for the fully connected layer. To save trouble and stay focused on our original goal, the minimal dense-only model below reproduces the problem.
# Minimal dense-only model used to reproduce the quantization problem.
m = keras.Sequential([
    keras.layers.Dense(100, input_shape=(5,)),
    keras.layers.Dense(100),
    keras.layers.Dense(3),
])
m.save("model2.h5")

x = np.random.randn(100, 5).astype('float32')


def data_gen():
    """Yield each row of ``x`` as a (1, 5) batch."""
    yield from (x[None, row] for row in range(100))


def data_gen2():
    """Representative-dataset generator: wrap each batch from ``data_gen`` for the converter."""
    for batch in data_gen():
        yield [list(batch)]


conv = (lite.TFLiteConverter.from_keras_model_file("model2.h5") if ver1_flag
        else lite.TFLiteConverter.from_keras_model(m))
conv.optimizations = [lite.Optimize.DEFAULT]
conv.representative_dataset = data_gen2
with open("problem.tflite", "wb") as f:
    f.write(conv.convert())
intp = lite.Interpreter("problem.tflite")
# NOTE(review): per the text above, this is where the problem traced to
# GetQuantizedConvolutionMultipler is expected to surface -- confirm.
intp.allocate_tensors()
Comments