“冻结”一些variables/范围在tensorflow:stop_gradient与传递variables最小化

我试图实施敌对neural network ,它需要在交替训练minibatches期间“冻结”图中的一个或另一个部分。 即有两个子networking:G和D.

G( Z ) -> Xz D( X ) -> Y 

G损失函数依赖于D[G(Z)], D[X]

首先,我需要在所有G参数固定的情况下对D中的参数进行训练,然后使用D中的参数固定G中的参数。 第一种情况下的损失函数将是第二种情况下的负损失函数,并且更新将必须应用于第一或第二子networking的参数。

我看到tensorflow有tf.stop_gradient函数。 为了训练D(下游)子networking,我可以使用这个function来阻止梯度stream向

  Z -> [ G ] -> tf.stop_gradient(Xz) -> [ D ] -> Y 

tf.stop_gradient非常简洁,没有内联示例(例如seq2seq.py太长而且不容易阅读),但看起来像在图创build期间必须调用它。 这是否意味着,如果我想要交替批量阻止/取消阻止梯度stream,我需要重新创build并重新初始化图模型?

此外,似乎tf.stop_gradient阻止stream经G(上游)networking的tf.stop_gradient ,对吧?

作为替代scheme,我看到可以将优化器调用的variables列表作为opt_op = opt.minimize(cost, <list of variables>)如果可以获得每个variables的范围中的所有variables子网。 一个人可以得到一个<list of variables>为一个tf.scope?

在你的问题中提到的最简单的方法就是创build两个优化器操作,分别调用opt.minimize(cost, ...) 。 默认情况下,优化器将使用tf.trainable_variables()所有variables。 如果要将variables过滤到特定范围,可以使用可选的scope参数tf.get_collection() ,如下所示:

 optimizer = tf.train.AdagradOptimzer(0.01) first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "scope/prefix/for/first/vars") first_train_op = optimizer.minimize(cost, var_list=first_train_vars) second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "scope/prefix/for/second/vars") second_train_op = optimizer.minimize(cost, var_list=second_train_vars) 

你可能想要考虑的另一个select是你可以在一个variables上设置trainable = False。 这意味着它不会被训练修改。

 tf.Variable(my_weights, trainable=False) 

@ mrry的回答是完全正确的,也许比我想提出的更一般化。 但我认为一个简单的方法来完成它只是将python引用直接传递给var_list

 W = tf.Variable(...) C = tf.Variable(...) Y_est = tf.matmul(W,C) loss = tf.reduce_sum((data-Y_est)**2) optimizer = tf.train.AdamOptimizer(0.001) # You can pass the python object directly train_W = optimizer.minimize(loss, var_list=[W]) train_C = optimizer.minimize(loss, var_list=[C]) 

我在这里有一个独立的例子: https : //gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a

我不知道我的方法是否有缺陷,但是我用这个结构解决了这个问题:

 do_gradient = <Tensor that evaluates to 0 or 1> no_gradient = 1 - do_gradient wrapped_op = do_gradient * original + no_gradient * tf.stop_gradient(original) 

所以,如果do_gradient = 1 ,那么值和渐变将会stream过,但是如果do_gradient = 0 ,那么值将只stream过stop_gradient op,这将停止梯度回stream。

对于我的场景,将do_gradient挂钩到random_shuffle张量的索引让我随机训练不同的networking片段。