TensorFlow Session restore checkpoint

最近写了一些TensorFlow的小程序,遇到了一个session无法restore checkpoint的问题,写法非常简单,使用的是

with tf.train.MonitoredTrainingSession(master = server.target,
                                           is_chief = task_index == 0,
                                           checkpoint_dir= checkpoint_dir,
                                           save_checkpoint_secs=20) as sess:

按理说,tf.train.MonitoredTrainingSession能够save和restore checkpoint,经过测试发现,save是没问题的,但是每次训练都是新的模型,没有持续训练。
经过多次查找,才发现正确的写法。
根据TF官网,save 和restore只需要使用tf.train.Saver()就可以解决问题。但是我根据官网去实现,却报了很多错误,最后发现,通过以下方法才可以,首先要获取latest checkpoint文件,就是最近的checkpoint文件,包括模型参数。
然后将模型restore出来,但是别忘了之前要reset这次训练的graph。
代码如下:

checkpoint_dir = "hdfs://emr-header-1:9000/movie"
saver = tf.train.Saver()
epoch = 0

with tf.train.MonitoredTrainingSession(master = server.target,
                                           is_chief = task_index == 0,
                                           checkpoint_dir= checkpoint_dir,
                                           save_checkpoint_secs=20) as sess:
     tf.reset_default_graph()
     sess.run(init)
     latest_path = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
     saver.restore(sess, latest_path)
Print Friendly

jiang yu

Leave a Reply