Comparing TF Checkpoints in v1.x and v2.x

Outline

  • How do we create a checkpoint?
  • What does a checkpoint look like?

Creating a checkpoint

First, we build the same basic LeNet-5-style model for MNIST that we used previously.

In [0]:
%tensorflow_version 2.x
from tensorflow import keras
import tensorflow as tf2
import tensorflow.compat.v1 as tf1
# TF1 example: uncomment the next line. TF2 example: keep it commented out.
# Restart the runtime whenever you switch between the two.
# tf1.disable_eager_execution()

tf2.__version__
TensorFlow 2.x selected.
Out[0]:
'2.1.0-rc1'
In [0]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(10000, 28,28, 1).astype('float32') / 255
y_train = y_train.astype('float32')
y_test = y_test.astype('float32')

keras.backend.clear_session()
def get_model():
  model = keras.Sequential([
      keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),
      keras.layers.BatchNormalization(),
      keras.layers.MaxPool2D(),
      keras.layers.Conv2D(16, 3, activation='relu'),
      keras.layers.BatchNormalization(),
      keras.layers.MaxPool2D(),
      keras.layers.Flatten(),
      keras.layers.Dense(128, activation='relu'),
      keras.layers.Dense(10, activation='softmax'),
  ])

  opt = keras.optimizers.Adam()
  model.compile(optimizer=opt,
          loss=keras.losses.SparseCategoricalCrossentropy(),
          metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return model, opt
  • TF v1.x: we create the checkpoint with tf1.train.Saver, following the TF v1 documentation. This part needs tf1.disable_eager_execution() to be active.
In [0]:
!rm ./tf_ckpt/*
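m, opt = get_model()
# Saver() with no arguments captures the variables that exist in the
# default graph at this point, so it is created after the model.
saver = tf1.train.Saver()
# Keras picks up the default session, so fit() and saver.save() share `sess`.
with tf1.Session() as sess:
  for i in range(2):
    m.fit(x_train, y_train, batch_size=128)
    # Same prefix every epoch: each save overwrites the previous one.
    saver.save(sess, "./tf_ckpt/model")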
rm: cannot remove './tf_ckpt/*': No such file or directory
WARNING:tensorflow:From /tensorflow-2.1.0/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Train on 60000 samples
60000/60000 [==============================] - 33s 548us/sample - loss: 0.1790 - sparse_categorical_accuracy: 0.9449
Train on 60000 samples
60000/60000 [==============================] - 32s 534us/sample - loss: 0.0484 - sparse_categorical_accuracy: 0.9847
In [0]:
!ls -lh ./tf_ckpt/
total 396K
-rw-r--r-- 1 root root   67 Dec 18 03:31 checkpoint
-rw-r--r-- 1 root root 216K Dec 18 03:31 model.data-00000-of-00001
-rw-r--r-- 1 root root  629 Dec 18 03:31 model.index
-rw-r--r-- 1 root root 172K Dec 18 03:31 model.meta
In [0]:
!cat ./tf_ckpt/checkpoint
model_checkpoint_path: "model"
all_model_checkpoint_paths: "model"
  • TF v2.x: we create the checkpoint objects, tf2.train.Checkpoint and tf2.train.CheckpointManager (see the TF doc). To keep things simple, we skip the custom tf.GradientTape training loop and just call model.fit from Keras. Before running this part, restart the runtime and leave disable_eager_execution commented out.
In [0]:
!rm tf_ckpt_v2/*
keras.backend.clear_session()
m, opt = get_model()
ckpt = tf2.train.Checkpoint(step=tf2.Variable(1), optimizer=opt, net=m)
manager = tf2.train.CheckpointManager(ckpt, './tf_ckpt_v2', max_to_keep=3)
rm: cannot remove 'tf_ckpt_v2/*': No such file or directory
In [0]:
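# latest_checkpoint is None on the first run, so restore() has nothing to load.
ckpt.restore(manager.latest_checkpoint)
for i in range(2):
  m.fit(x_train, y_train, batch_size=128)
  ckpt.step.assign_add(1)
  # The ckpt-N suffix comes from the checkpoint's own save counter, not from `step`.
  save_path = manager.save()
  print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))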
Train on 60000 samples
60000/60000 [==============================] - 32s 540us/sample - loss: 0.1831 - sparse_categorical_accuracy: 0.9452
Saved checkpoint for step 2: ./tf_ckpt_v2/ckpt-1
Train on 60000 samples
60000/60000 [==============================] - 32s 539us/sample - loss: 0.0466 - sparse_categorical_accuracy: 0.9857
Saved checkpoint for step 3: ./tf_ckpt_v2/ckpt-2
In [0]:
!ls -lh ./tf_ckpt_v2/
total 1.3M
-rw-r--r-- 1 root root  254 Dec 18 03:35 checkpoint
-rw-r--r-- 1 root root 653K Dec 18 03:34 ckpt-1.data-00000-of-00001
-rw-r--r-- 1 root root 3.1K Dec 18 03:34 ckpt-1.index
-rw-r--r-- 1 root root 653K Dec 18 03:35 ckpt-2.data-00000-of-00001
-rw-r--r-- 1 root root 3.1K Dec 18 03:35 ckpt-2.index
In [0]:
!cat ./tf_ckpt_v2/checkpoint
model_checkpoint_path: "ckpt-2"
all_model_checkpoint_paths: "ckpt-1"
all_model_checkpoint_paths: "ckpt-2"
all_model_checkpoint_timestamps: 1576640070.138009
all_model_checkpoint_timestamps: 1576640102.5460615
last_preserved_timestamp: 1576640036.6357675
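
The checkpoint file printed above is plain text with one `key: value` pair per line, so we can read it without TensorFlow. The sketch below is only an illustration: parse_checkpoint_state is a hypothetical helper, not a TensorFlow function, and it assumes simple paths without escaped characters. In real code, tf.train.latest_checkpoint gives the latest prefix directly.

```python
import re
from collections import defaultdict

# Contents of ./tf_ckpt_v2/checkpoint, copied from the cell output above.
STATE_TEXT = '''\
model_checkpoint_path: "ckpt-2"
all_model_checkpoint_paths: "ckpt-1"
all_model_checkpoint_paths: "ckpt-2"
all_model_checkpoint_timestamps: 1576640070.138009
all_model_checkpoint_timestamps: 1576640102.5460615
last_preserved_timestamp: 1576640036.6357675
'''


def parse_checkpoint_state(text):
    """Hypothetical helper: collect `key: value` lines into a dict of lists."""
    fields = defaultdict(list)
    for line in text.splitlines():
        match = re.match(r'^(\w+):\s*(.+)$', line.strip())
        if not match:
            continue
        key, raw = match.groups()
        # Quoted values are checkpoint prefixes; the rest are float timestamps.
        value = raw.strip('"') if raw.startswith('"') else float(raw)
        fields[key].append(value)
    return dict(fields)


state = parse_checkpoint_state(STATE_TEXT)
print("latest:", state["model_checkpoint_path"][0])
print("retained:", state["all_model_checkpoint_paths"])
```

Repeated keys such as all_model_checkpoint_paths become lists, one entry per retained checkpoint; with max_to_keep=3 the manager would keep at most three of them.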

Differences between 1.x and 2.x

In TF 1.x, a checkpoint consists of four file types (excerpt):

  • checkpoint: a small text file that records the path of the latest checkpoint and of every retained checkpoint.
  • *.index: a string-to-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files holds its content, the offset into that file, a checksum, some auxiliary data, etc.
  • *.data-*-of-*: the TensorBundle data shards, which store the values of all variables.
  • *.meta: describes the saved graph structure, including the GraphDef, SaverDef, and so on. Calling tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the Saver and the Graph.
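
To see these four file types without training a full model, here is a minimal sketch that saves a one-variable graph with the v1 Saver and then lists the tensor names stored in the checkpoint. The variable name "w", its shape, and the temporary directory are assumptions made up for this illustration.

```python
import os
import tempfile

import tensorflow as tf

tmpdir = tempfile.mkdtemp()  # scratch directory, only for this sketch

# Build a tiny TF1-style graph holding a single variable named "w".
graph = tf.Graph()
with graph.as_default():
    w = tf.compat.v1.get_variable(
        "w", shape=[2, 3], initializer=tf.compat.v1.zeros_initializer())
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        prefix = saver.save(sess, os.path.join(tmpdir, "model"))

# The directory should now hold the four file types described above.
files = sorted(os.listdir(tmpdir))
print(files)

# list_variables reads the (name, shape) pairs recorded in the checkpoint.
variables = tf.train.list_variables(prefix)
print(variables)
```

The names returned by tf.train.list_variables are the graph variable names, which is what "variable.name based" means for v1 checkpoints.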

In TF 2.x, the .meta file is no longer written, in line with the removal of sessions and graphs from the user-facing API. According to the TF documentation:

Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to TensorFlow 1.x's tf.compat.v1.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables.
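
A small sketch of what "object-based" means in practice: we save a variable under the attribute name weights, then restore it into a variable with a different variable.name. The attribute name, variable names, and temporary directory are assumptions for this illustration. In the TF 2.x versions we are aware of, the stored key is built from the attribute path (something like weights/.ATTRIBUTES/VARIABLE_VALUE) rather than from variable.name, but treat the exact key format as an implementation detail.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tmpdir = tempfile.mkdtemp()  # scratch directory, only for this sketch

# Save a variable attached to the checkpoint under the edge name "weights".
v = tf.Variable([1.0, 2.0], name="original_name")
prefix = tf.train.Checkpoint(weights=v).save(os.path.join(tmpdir, "ckpt"))

# Keys follow the object graph (edge names), not variable.name.
keys = [name for name, _ in tf.train.list_variables(prefix)]
print(keys)

# No .meta file is written by tf.train.Checkpoint.
files = sorted(os.listdir(tmpdir))
print(files)

# Restore into a differently named variable attached under the same edge name.
restored = tf.Variable([0.0, 0.0], name="renamed")
tf.train.Checkpoint(weights=restored).restore(prefix)
print(restored.numpy())
```

Because matching follows the named edge weights, renaming the variable in the Python program does not break the restore, which is the robustness the quoted documentation refers to.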

Future post

  • How to restore from a checkpoint?
  • How to inspect a checkpoint's graph and nodes?
