Tensorflow:如何保存/恢复模型?

在Tensorflow中训练模型之后:

  1. 你如何保存训练有素的模型?
  2. 你以后如何恢复这个保存的模型?

我正在改进我的答案,以添加更多的细节保存和恢复模型。

在(和之后)Tensorflow版本0.11:

保存模型:

import tensorflow as tf #Prepare to feed input, ie feed_dict and placeholders w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} #Define a test operation that we will restore w3 = tf.add(w1,w2) w4 = tf.multiply(w3,b1,name="op_to_restore") sess = tf.Session() sess.run(tf.global_variables_initializer()) #Create a saver object which will save all the variables saver = tf.train.Saver() #Run the operation by feeding input print sess.run(w4,feed_dict) #Prints 24 which is sum of (w1+w2)*b1 #Now, save the graph saver.save(sess, 'my_test_model',global_step=1000) 

还原模型:

 import tensorflow as tf sess=tf.Session() #First let's load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) # Access saved Variables directly print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved # Now, let's access and create placeholders variables and # create feed-dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0") print sess.run(op_to_restore,feed_dict) #This will print 60 which is calculated 

这里和一些更高级的用例在这里已经很好的解释了。

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

在(及之后的)TensorFlow版本0.11.0RC1中,您可以通过根据https://www.tensorflow.org/programmers_guide/meta_graph调用;tf.train.export_meta_graphtf.train.import_meta_graph来直接保存和恢复您的模型。

保存模型

 w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1') w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2') tf.add_to_collection('vars', w1) tf.add_to_collection('vars', w2) saver = tf.train.Saver() sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, 'my-model') # `save` method will call `export_meta_graph` implicitly. # you will get saved graph files:my-model.meta 

还原模型

 sess = tf.Session() new_saver = tf.train.import_meta_graph('my-model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) all_vars = tf.get_collection('vars') for v in all_vars: v_ = sess.run(v) print(v_) 

对于TensorFlow版本<0.11.0RC1:

保存的检查点包含模型中Variable的值,而不包含模型/graphics本身,这意味着在恢复检查点时graphics应该是相同的。

下面是一个线性回归的例子,其中有一个保存可变检查点的训练循环和一个评估部分,可以恢复先前运行中保存的variables并计算预测。 当然,你也可以恢复variables,并继续训练,如果你愿意。

 x = tf.placeholder(tf.float32) y = tf.placeholder(tf.float32) w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32)) b = tf.Variable(tf.ones([1, 1], dtype=tf.float32)) y_hat = tf.add(b, tf.matmul(x, w)) ...more setup for optimization and what not... saver = tf.train.Saver() # defaults to saving all variables - in this case w and b with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if FLAGS.train: for i in xrange(FLAGS.training_steps): ...training loop... if (i + 1) % FLAGS.checkpoint_steps == 0: saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step=i+1) else: # Here's where you're restoring the variables w and b. # Note that the graph is exactly as it was when the variables were # saved in a prior training run. ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: ...no checkpoint found... # Now you can run the model to get predictions batch_x = ...load some data... predictions = sess.run(y_hat, feed_dict={x: batch_x}) 

这里是Variable的文档 ,涵盖保存和恢复。 这里是Saver的文档 。

模型有两部分,模型定义,由Supervisor作为graph.pbtxt保存在模型目录中,张量的数值保存到model.ckpt-1003418类的检查点文件中。

使用tf.import_graph_def可以恢复模型定义,并使用Saver来恢复权重。

但是, Saver使用附加在模型Graph上的特殊集合保存variables列表,并且这个集合没有使用import_graph_def进行初始化,所以现在不能一起使用这两个variables(这是我们的路线图)。 现在,您必须使用Ryan Sepassi的方法 – 手动构build具有相同节点名称的graphics,并使用Saver将权重加载到其中。

(或者你可以通过使用import_graph_def ,手动创buildvariables,并为每个variables使用tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) ,然后使用Saver来破解它)

你也可以采取这个更简单的方法。

第1步:初始化所有variables

 W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1") B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1") Similarly, W2, B2, W3, ..... 

步骤2:将会话保存在模型保存Saver并保存

 model_saver = tf.train.Saver() # Train the model and save it in the end model_saver.save(session, "saved_models/CNN_New.ckpt") 

第3步:恢复模型

 with tf.Session(graph=graph_cnn) as session: model_saver.restore(session, "saved_models/CNN_New.ckpt") print("Model restored.") print('Initialized') 

第四步:检查你的variables

 W1 = session.run(W1) print(W1) 

在不同的python实例中运行时,使用

 with tf.Session() as sess: # Restore latest checkpoint saver.restore(sess, tf.train.latest_checkpoint('saved_model/.')) # Initalize the variables sess.run(tf.global_variables_initializer()) # Get default graph (supply your custom graph if you have one) graph = tf.get_default_graph() # It will give tensor object W1 = graph.get_tensor_by_name('W1:0') # To get the value (numpy array) W1_value = session.run(W1) 

正如雅罗斯拉夫说,你可以通过导入graphics,手动创buildvariables,然后使用Saver来破解graph_def和checkpoint。

我实现了这个为我个人使用,所以我虽然我会在这里分享代码。

链接: https : //gist.github.com/nikitakit/6ef3b72be67b86cb7868

(当然,这是一个黑客攻击,不能保证以这种方式保存的模型在未来的TensorFlow版本中仍然可读。)

在大多数情况下,使用tf.train.Saver保存和从磁盘恢复是最好的select:

 ... # build your model saver = tf.train.Saver() with tf.Session() as sess: ... # train the model saver.save(sess, "/tmp/my_great_model") with tf.Session() as sess: saver.restore(sess, "/tmp/my_great_model") ... # use the model 

您也可以保存/恢复graphics结构本身(有关详细信息,请参阅MetaGraph文档 )。 默认情况下,保存器将graphics结构保存到.meta文件中。 你可以调用import_meta_graph()来恢复它。 它将恢复graphics结构并返回一个可用于恢复模型状态的Saver

 saver = tf.train.import_meta_graph("/tmp/my_great_model.meta") with tf.Session() as sess: saver.restore(sess, "/tmp/my_great_model") ... # use the model 

但是,有些情况下您需要更快的速度。 例如,如果实现提前停止,则每次在培训期间模型改进时(如在validation集上测量),都希望保存检查点,如果在一段时间内没有进展,则要回滚到最佳模型。 如果每次改进都将模型保存到磁盘,则会极大地减慢训练速度。 诀窍是将variables状态保存到内存中 ,然后稍后恢复:

 ... # build your model # get a handle on the graph nodes we need to save/restore the model graph = tf.get_default_graph() gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars] init_values = [assign_op.inputs[1] for assign_op in assign_ops] with tf.Session() as sess: ... # train the model # when needed, save the model state to memory gvars_state = sess.run(gvars) # when needed, restore the model state feed_dict = {init_value: val for init_value, val in zip(init_values, gvars_state)} sess.run(assign_ops, feed_dict=feed_dict) 

快速说明:当您创buildvariablesX ,TensorFlow会自动创build一个赋值操作X/Assign来设置variables的初始值。 我们只是使用这些现有的赋值操作,而不是创build占位符和额外的赋值操作(这只会使graphics变得凌乱)。 每个赋值op的第一个input是对它应该初始化的variables的引用,第二个input( assign_op.inputs[1] )是初始值。 因此,为了设置我们想要的任何值(而不是初始值),我们需要使用feed_dict并replace初始值。 是的,TensorFlow让你为任何操作提供一个值,而不仅仅是占位符,所以这工作正常。

如果它是内部保存的模型,则只需为所有variables指定一个恢复器即可

 restorer = tf.train.Saver(tf.all_variables()) 

并用它来恢复当前会话中的variables:

 restorer.restore(self._sess, model_file) 

对于外部模型,您需要指定从其variables名称到variables名称的映射。 您可以使用该命令查看模型variables名称

 python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt 

inspect_checkpoint.py脚本可以在Tensorflow源文件的'./tensorflow/python/tools'文件夹中find。

要指定映射,可以使用我的Tensorflow-Worklab ,它包含一组类和脚本来训练和再培训不同的模型。 它包括一个ResNet ResNet模型的例子,位于这里

这里是我的两个基本案例的简单解决scheme,不同之处在于是否要从文件加载graphics或在运行时构build它。

这个答案适用于Tensorflow 0.12+(含1.0)。

在代码中重buildgraphics

保存

 graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model') 

载入中

 graph = ... # build the graph saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.restore(sess, tf.train.latest_checkpoint('./')) # now you can use the graph, continue training or whatever 

从文件加载graphics

使用这种技术时,确保所有图层/variables都明确地设置了唯一的名称。 否则,Tensorflow会使名称本身具有唯一性,并因此与存储在文件中的名称不同。 在以前的技术中这不是问题,因为在加载和保存时名称都是“相同的”。

保存

 graph = ... # build the graph for op in [ ... ]: # operators you want to use after restoring the model tf.add_to_collection('ops_to_restore', op) saver = tf.train.Saver() # create the saver after the graph with ... as sess: # your session object saver.save(sess, 'my-model') 

载入中

 with ... as sess: # your session object saver = tf.train.import_meta_graph('my-model.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) ops = tf.get_collection('ops_to_restore') # here are your operators in the same order in which you saved them to the collection 

您还可以在TensorFlow / skflow中查看示例 ,该示例提供了可帮助您轻松pipe理模型的saverestore方法。 它具有参数,您还可以控制备份模型的频率。

如果使用tf.train.MonitoredTrainingSession作为默认会话,则不需要添加额外的代码来执行保存/恢复操作。 只要将检查点目录名称传递给MonitoredTrainingSession的构造函数,它将使用会话挂钩来处理这些。

如问题6255中所述 :

 use '**./**model_name.ckpt' saver.restore(sess,'./my_model_final.ckpt') 

代替

 saver.restore('my_model_final.ckpt') 

这里的所有答案都很好,但我想补充两件事情。

首先,为了详细说明@ user7505159的答案,“./”可以添加到要恢复的文件名的开头。

例如,您可以在文件名中保存一个不带“./”的graphics,如下所示:

 # Some graph defined up here with specific names saver = tf.train.Saver() save_file = 'model.ckpt' with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, save_file) 

但是为了恢复graphics,可能需要在file_name中加上“./”:

 # Same graph defined up here saver = tf.train.Saver() save_file = './' + 'model.ckpt' # String addition used for emphasis with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess, save_file) 

你不会总是需要“./”,但是它可能会导致问题,这取决于你的环境和TensorFlow的版本。

还要提到在恢复会话之前, sess.run(tf.global_variables_initializer())可能很重要。

如果在尝试还原保存的会话时收到有关未初始化variables的错误,请确保在saver.restore(sess, save_file)行之前包含sess.run(tf.global_variables_initializer()) 。 它可以节省你头痛。

我的环境:Python 3.6,Tensorflow 1.3.0

虽然有很多解决scheme,但大部分都是基于tf.train.Saver 。 当我们加载由Saver保存的.ckpt ,我们必须重新定义tensorflownetworking或使用一些奇怪而难记的名字,例如'placehold_0:0''dense/Adam/Weight:0' 。 在这里我推荐使用tf.saved_model ,下面给出一个最简单的例子,你可以从服务一个TensorFlow模型了解更多:

保存模型:

 import tensorflow as tf # define the tensorflow network and do some trains x = tf.placeholder("float", name="x") w = tf.Variable("float", name="w") b = tf.Variable(0.0, name="bias") h = tf.multiply(x, w) y = tf.add(h, b, name="y") sess = tf.Session() sess.run(tf.global_variables_initializer()) # save the model export_path = './savedmodel' builder = tf.saved_model.builder.SavedModelBuilder(export_path) tensor_info_x = tf.saved_model.utils.build_tensor_info(x) tensor_info_y = tf.saved_model.utils.build_tensor_info(y) prediction_signature = ( tf.saved_model.signature_def_utils.build_signature_def( inputs={'x_input': tensor_info_x}, outputs={'y_output': tensor_info_y}, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], signature_def_map={ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature }, ) 

加载模型:

 import tensorflow as tf sess=tf.Session() signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY input_key = 'x_input' output_key = 'y_output' export_path = './savedmodel' meta_graph_def = tf.saved_model.loader.load( sess, [tf.saved_model.tag_constants.SERVING], export_path) signature = meta_graph_def.signature_def x_tensor_name = signature[signature_key].inputs[input_key].name y_tensor_name = signature[signature_key].outputs[output_key].name x = sess.graph.get_tensor_by_name(x_tensor_name) y = sess.graph.get_tensor_by_name(y_tensor_name) y_out = sess.run(y, {x: 3.0})