tf.app.run()如何工作?

tf.app.run()如何在Tensorflow翻译演示中工作? 在tensorflow / models / rnn / translate / translate.py中,有一个对tf.app.run()的调用。 它是如何处理的?

if __name__ == "__main__": tf.app.run() 

这只是一个非常快速的包装,处理标志parsing,然后分派到您自己的主。 看代码

 if __name__ == "__main__": 

意味着当前文件在shell下执行,而不是作为模块导入。

 tf.app.run() 

正如你可以看到通过文件app.py

 def run(main=None, argv=None): """Runs the program with an optional 'main' function and 'argv' list.""" f = flags.FLAGS # Extract the args from the optional `argv` list. args = argv[1:] if argv else None # Parse the known flags from that list, or from the command # line otherwise. # pylint: disable=protected-access flags_passthrough = f._parse_flags(args=args) # pylint: enable=protected-access main = main or sys.modules['__main__'].main # Call the main function, passing through any arguments # to the final program. sys.exit(main(sys.argv[:1] + flags_passthrough)) 

让我们一行一行地打破:

 flags_passthrough = f._parse_flags(args=args) 

这确保了你通过命令行parsing的参数是有效的,例如python my_model.py --data_dir='...' --max_iteration=10000 ,这个特性是基于python标准argparse模型实现的。

 main = main or sys.modules['__main__'].main 

=右边的第一个main参数是当前函数run(main=None, argv=None)的第一个参数。 sys.modules['__main__']表示当前正在运行的文件(例如my_model.py )。

所以有两种情况:

  1. 您在my_model.py没有main函数然后您必须调用tf.app.run(my_main_running_function)

  2. 你在my_model.py有一个mainfunction。 (这是最重要的。)

最后一行:

 sys.exit(main(sys.argv[:1] + flags_passthrough)) 

确保您的main(argv)my_main_running_function(argv)函数正确地使用parsing的参数进行调用。

tf.app没有什么特别之处。 这只是一个通用的入口点脚本 ,

使用可选的“主”function和“argv”列表运行程序。

它与neural network无关,它只是调用主函数,通过任何parameter passing给它。

简单来说, tf.app.run()的工作是首先设置全局标志以备后用,如:

 from tensorflow.python.platform import flags f = flags.FLAGS 

然后使用一组参数运行您的自定义主函数。

例如在TensorFlow NMT代码库中,训练/推断程序执行的第一个入口点就是从这一点开始的(见下面的代码)

 if __name__ == "__main__": nmt_parser = argparse.ArgumentParser() add_arguments(nmt_parser) FLAGS, unparsed = nmt_parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 

在使用argparseparsing参数之后,使用tf.app.run()运行定义如下的函数“main”:

 def main(unused_argv): default_hparams = create_hparams(FLAGS) train_fn = train.train inference_fn = inference.inference run_main(FLAGS, default_hparams, train_fn, inference_fn) 

所以,在为全局使用设置标志之后, tf.app.run()简单地运行那个以argv作为parameter passing给它的argv函数。

PS:正如萨尔瓦多•达利(Salvador Dali)的回答所说,这只是一个很好的软件工程实践,我猜想,尽pipe我不确定TensorFlow是否执行了main函数的优化运行,而不是使用普通的CPython运行。