TensorFlow,为什么保存模型后有3个文件?

读过文档后 ,我在TensorFlow保存了一个模型,这里是我的演示代码:

 # Create some variables. v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Add an op to initialize the variables. init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables. saver = tf.train.Saver() # Later, launch the model, initialize the variables, do some work, save the # variables to disk. with tf.Session() as sess: sess.run(init_op) # Do some work with the model. .. # Save the variables to disk. save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in file: %s" % save_path) 

但之后,我发现有3个文件

 model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta 

而且我不能通过恢复model.ckpt文件来恢复模型,因为没有这样的文件。 这是我的代码

 with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "/tmp/model.ckpt") 

那么,为什么有3个文件?

尝试这个:

 with tf.Session() as sess: saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta') saver.restore(sess, "/tmp/model.ckpt") 

TensorFlow保存方法保存三种文件,因为它将图结构variables值分开存储。 .meta文件描述了保存的graphics结构,因此在恢复检查点之前需要导入它(否则它不知道保存的检查点值对应的variables)。

或者,你可以这样做:

 # Recreate the EXACT SAME variables v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Now load the checkpoint variable values with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "/tmp/model.ckpt") 

即使没有名为model.ckpt文件,在恢复时仍然会通过该名称引用保存的检查点。 来自saver.py源代码 :“用户只需要与用户指定的前缀进行交互…而不是任何物理path名称”。

  • 元文件 :描述保存的graphics结构,包括GraphDef,SaverDef等等; 然后应用tf.train.import_meta_graph('/tmp/model.ckpt.meta') ,将恢复SaverGraph

  • 索引文件 :它是一个string不可变表(tensorflow :: table :: Table)。 每个键是一个张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:哪个“数据”文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等

  • 数据文件 :是TensorBundle集合,保存所有variables的值。

我正在从Word2Vec tensorflow教程恢复受过训练的单词embedded。

如果您创build了多个检查点:

例如创build的文件看起来像这样

model.ckpt-55695.data 00000-的-00001

model.ckpt-55695.index

model.ckpt-55695.meta

尝试这个

 def restore_session(self, session): saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta') saver.restore(session, './tmp/model.ckpt-55695') 

当调用restore_session()时:

 def test_word2vec(): opts = Options() with tf.Graph().as_default(), tf.Session() as session: with tf.device("/cpu:0"): model = Word2Vec(opts, session) model.restore_session(session) model.get_embedding("assistance")