Comparing TF Checkpoints in v1.x and v2.x¶
Outline¶
- How to create a checkpoint?
- What does a checkpoint look like?
Creating a checkpoint¶
First, we build the basic LeNet-5-style model on MNIST, as we did previously.
%tensorflow_version 2.x
from tensorflow import keras
import tensorflow as tf2
import tensorflow.compat.v1 as tf1
# for the TF 1.x example below, uncomment the next line;
# keep it commented out (and restart the runtime) for the TF 2.x example
#tf1.disable_eager_execution()
tf2.__version__
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(10000, 28,28, 1).astype('float32') / 255
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')
keras.backend.clear_session()
def get_model():
    model = keras.Sequential([
        keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Conv2D(16, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPool2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    opt = keras.optimizers.Adam()
    model.compile(optimizer=opt,
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])
    return model, opt
- In TF v1.x, we create the checkpoint following the TF v1 docs:
!rm ./tf_ckpt/*
m, opt = get_model()
saver = tf1.train.Saver()
with tf1.Session() as sess:
    for i in range(2):
        m.fit(x_train, y_train, batch_size=128)
    saver.save(sess, "./tf_ckpt/model")
!ls -lh ./tf_ckpt/
!cat ./tf_ckpt/checkpoint
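To sanity-check what was written, we can list the (name, shape) pairs stored in the checkpoint. This is a minimal sketch using the prefix ./tf_ckpt/model from the save call above; note that the keys are plain variable names (something like conv2d/kernel).
# Sketch: list the name-based keys stored in the v1 checkpoint
for name, shape in tf1.train.list_variables("./tf_ckpt/model"):
    print(name, shape)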
- In TF v2.x, we create the checkpoint-related objects (see the TF docs). However, to simplify the process, we skip the confusing tf.GradientTape training loop and just use model.fit from Keras. Remember to restart the runtime and keep disable_eager_execution commented out.
!rm tf_ckpt_v2/*
keras.backend.clear_session()
m, opt = get_model()
ckpt = tf2.train.Checkpoint(step=tf2.Variable(1), optimizer=opt, net=m)
manager = tf2.train.CheckpointManager(ckpt, './tf_ckpt_v2', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
for i in range(2):
    m.fit(x_train, y_train, batch_size=128)
    ckpt.step.assign_add(1)
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
!ls -lh ./tf_ckpt_v2/
!cat ./tf_ckpt_v2/checkpoint
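For comparison, the same listing on the v2 checkpoint. A minimal sketch; the keys are now object paths with named edges (roughly of the form net/layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE, though the exact format is an internal detail), plus bookkeeping entries such as save_counter.
# Sketch: list the object-path keys stored in the v2 checkpoint
for name, shape in tf2.train.list_variables(manager.latest_checkpoint):
    print(name, shape)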
Differences between 1.x and 2.x¶
In TF 1.x, a checkpoint consists of four file types (descriptions excerpted from the TF docs):
- checkpoint: the checkpoint path index.
- *.index: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, a checksum, some auxiliary data, etc. (see the sketch after this list).
- *.data-*-of-*: the TensorBundle collection, which saves the values of all variables.
- *.meta: describes the saved graph structure, including GraphDef, SaverDef, and so on; applying tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the Saver and Graph.
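The index/data split can be inspected directly with the checkpoint reader. A minimal sketch, assuming the v1 checkpoint prefix from above: the reader resolves each key through the .index table and reads the tensor's value out of the .data-*-of-* shard(s).
# Sketch: resolve a key via the .index table, then read its value
reader = tf2.train.load_checkpoint("./tf_ckpt/model")
shape_map = reader.get_variable_to_shape_map()   # key -> shape, from *.index
first_key = sorted(shape_map)[0]
print(first_key, shape_map[first_key])
print(reader.get_tensor(first_key).shape)        # values live in *.data-*-of-*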
In TF 2.x, the meta file is gone, in line with the removal of sessions and graphs. From the TF docs:
Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to TensorFlow 1.x's tf.compat.v1.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables.
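To make the "graph of dependencies with named edges" concrete: the object graph is itself serialized into the checkpoint. The sketch below relies on two internal details that may change between releases, so treat the key name _CHECKPOINTABLE_OBJECT_GRAPH and the trackable_object_graph_pb2 module as assumptions.
# Sketch (internal details, may change between TF releases):
# decode the serialized object graph stored inside the v2 checkpoint
from tensorflow.core.protobuf import trackable_object_graph_pb2
reader = tf2.train.load_checkpoint(manager.latest_checkpoint)
obj_graph = trackable_object_graph_pb2.TrackableObjectGraph()
obj_graph.ParseFromString(reader.get_tensor('_CHECKPOINTABLE_OBJECT_GRAPH'))
# the root node's children are the named edges we passed to tf2.train.Checkpoint
for child in obj_graph.nodes[0].children:
    print(child.local_name)   # expect: step, optimizer, net, save_counter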
Future post¶
- How to restore from a checkpoint?
- Investigate the checkpoint's graph and nodes.