Comparing TF Checkpoints in v1.x and v2.x¶
Outline¶
- How do we create a checkpoint?
- What does a checkpoint look like?
Creating a checkpoint¶
First, we create the basic LeNet-5 model for MNIST, as we did previously.
%tensorflow_version 2.x
from tensorflow import keras
import tensorflow as tf2
import tensorflow.compat.v1 as tf1
# uncomment the next line for the TF 1.x example;
# re-comment it and restart the runtime for the TF 2.x example
#tf1.disable_eager_execution()
tf2.__version__
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(10000, 28, 28, 1).astype('float32') / 255
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')
keras.backend.clear_session()

def get_model():
    model = keras.Sequential([
        keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Conv2D(16, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    opt = keras.optimizers.Adam()
    model.compile(optimizer=opt,
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])
    return model, opt
- For TF 1.x, we create the checkpoint with `tf1.train.Saver`, following the TF v1 documentation.
!rm ./tf_ckpt/*
m, opt = get_model()
saver = tf1.train.Saver()
with tf1.Session() as sess:
    for i in range(2):
        m.fit(x_train, y_train, batch_size=128)
        saver.save(sess, "./tf_ckpt/model")
!ls -lh ./tf_ckpt/
!cat ./tf_ckpt/checkpoint
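If you want to peek at what `Saver` wrote, `tf1.train.NewCheckpointReader` can list the stored tensor names and shapes directly from the index/data files. A minimal sketch, assuming the `./tf_ckpt/model` prefix saved above:

reader = tf1.train.NewCheckpointReader("./tf_ckpt/model")
# map of variable name -> shape, read from the .index/.data files
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)

Note that the keys here are plain variable names such as `conv2d/kernel`.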
- For TF 2.x, we create the checkpoint-related objects (see the TF docs). However, to simplify the process, we skip the confusing `tf.GradientTape`; to keep it simple, we just use `model.fit` from Keras. Remember to restart the runtime and comment out `disable_eager_execution` first.
!rm tf_ckpt_v2/*
keras.backend.clear_session()
m, opt = get_model()
ckpt = tf2.train.Checkpoint(step=tf2.Variable(1), optimizer=opt, net=m)
manager = tf2.train.CheckpointManager(ckpt, './tf_ckpt_v2', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
for i in range(2):
    m.fit(x_train, y_train, batch_size=128)
    ckpt.step.assign_add(1)
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
!ls -lh ./tf_ckpt_v2/
!cat ./tf_ckpt_v2/checkpoint
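The same kind of inspection works here; a minimal sketch using `tf2.train.list_variables` on the checkpoint the manager just wrote. The object-path-style keys are a first hint of the difference discussed in the next section:

# list (name, shape) pairs from the latest v2 checkpoint; keys look like
# 'net/layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE'
for name, shape in tf2.train.list_variables(manager.latest_checkpoint):
    print(name, shape)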
Differences between 1.x and 2.x¶
In TF 1.x, a checkpoint consists of four file types (descriptions excerpted from the TF docs):
- `checkpoint`: the checkpoint path index.
- `*.index`: a string-string immutable table (`tensorflow::table::Table`). Each key is the name of a tensor and its value is a serialized `BundleEntryProto`. Each `BundleEntryProto` describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, a checksum, some auxiliary data, etc.
- `*.data-*-of-*`: the TensorBundle collection, which saves the values of all variables.
- `*.meta`: describes the saved graph structure, including `GraphDef`, `SaverDef`, and so on; applying `tf.train.import_meta_graph('/tmp/model.ckpt.meta')` restores the Saver and Graph (see the sketch after this list).
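As a sketch of that last point: in graph mode (i.e., with `disable_eager_execution` active), `import_meta_graph` rebuilds the saved graph and returns a `Saver`, which can then load the variable values from the data files. This assumes the `./tf_ckpt/model` prefix from the v1 example above:

with tf1.Session() as sess:
    # rebuild the graph from the .meta file and get a Saver back
    restorer = tf1.train.import_meta_graph("./tf_ckpt/model.meta")
    # load the variable values from the .index/.data files
    restorer.restore(sess, "./tf_ckpt/model")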
In TF 2.x, the `.meta` file is gone, in line with the removal of `Session` and the explicit graph. From the TF docs:
Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to TensorFlow 1.x's tf.compat.v1.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables.
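To make the "object graph" idea concrete, here is a minimal sketch of an object-based restore (covered fully in the future post below). A fresh model is wrapped in a `Checkpoint` with the same attribute names used at save time (`step`, `optimizer`, `net`), and variables are matched through those named edges rather than by variable name:

fresh_model, fresh_opt = get_model()
restore_ckpt = tf2.train.Checkpoint(step=tf2.Variable(1),
                                    optimizer=fresh_opt, net=fresh_model)
status = restore_ckpt.restore(manager.latest_checkpoint)
# raises if any object that already exists in Python was left unmatched
status.assert_existing_objects_matched()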
Future post¶
- How to restore from a checkpoint?
- Investigate the checkpoint at the graph/node level.