Tensorflow NaN错误?

我正在使用TensorFlow,并修改了教程示例以获取RGB图像。

这个algorithm在新的图像集合上完美地工作,直到突然(仍然收敛,通常约为92%的精确度),它与ReluGrad接收到非有限值的错误相冲突。 debugging表明,没有什么不寻常的事情发生的数字,直到非常突然,不明原因,错误抛出。 添加

print "max W vales: %g %g %g %g"%(tf.reduce_max(tf.abs(W_conv1)).eval(),tf.reduce_max(tf.abs(W_conv2)).eval(),tf.reduce_max(tf.abs(W_fc1)).eval(),tf.reduce_max(tf.abs(W_fc2)).eval()) print "max b vales: %g %g %g %g"%(tf.reduce_max(tf.abs(b_conv1)).eval(),tf.reduce_max(tf.abs(b_conv2)).eval(),tf.reduce_max(tf.abs(b_fc1)).eval(),tf.reduce_max(tf.abs(b_fc2)).eval()) 

作为每个循环的debugging代码,产生以下输出:

 Step 8600 max W vales: 0.759422 0.295087 0.344725 0.583884 max b vales: 0.110509 0.111748 0.115327 0.124324 Step 8601 max W vales: 0.75947 0.295084 0.344723 0.583893 max b vales: 0.110516 0.111753 0.115322 0.124332 Step 8602 max W vales: 0.759521 0.295101 0.34472 0.5839 max b vales: 0.110521 0.111747 0.115312 0.124365 Step 8603 max W vales: -3.40282e+38 -3.40282e+38 -3.40282e+38 -3.40282e+38 max b vales: -3.40282e+38 -3.40282e+38 -3.40282e+38 -3.40282e+38 

由于我的数值都不是很高,NaN可能发生的唯一方法是由0/0处理得不好,但由于这个教程代码没有做任何分割或类似的操作,我没有看到任何其他的解释,内部的TF代码。

我对如何处理这件事毫无头绪。 有什么build议么? 该algorithm收敛性很好,在validation集上的准确率稳步提高,在8600次迭代时达到了92.5%。

其实,原来是愚蠢的。 我发布这个,以防其他人会遇到类似的错误。

 cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) 

实际上是一个计算交叉熵的可怕方法。 在一些样本中,某些类可能在一段时间后被确定地排除,导致该样本的y_conv = 0。 这通常不是一个问题,因为你对这些不感兴趣,但是在cross_entropy写在那里的方式,对于那个特定的样本/类,它将产生0 * log(0)。 因此,NaN。

用它replace

 cross_entropy = -tf.reduce_sum(y_*tf.log(tf.clip_by_value(y_conv,1e-10,1.0))) 

解决了我所有的问题

实际上,剪切不是一个好主意,因为当达到阈值时,它将阻止梯度向后传播。 相反,我们可以在softmax输出中添加一点常数。

 cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv + 1e-10)) 

如果y_conv是softmax的结果,比如说y_conv = tf.nn.softmax(x) ,那么更好的解决办法是用log_softmaxreplace它:

 y = tf.nn.log_softmax(x) cross_entropy = -tf.reduce_sum(y_*y) 

具体的答案

 def cross_entropy(x, y, axis=-1): safe_y = tf.where(tf.equal(x, 0.), tf.ones_like(y), y) return -tf.reduce_sum(x * tf.log(safe_y), axis) def entropy(x, axis=-1): return cross_entropy(x, x, axis) 

但它工作?

 x = tf.constant([0.1, 0.2, 0., 0.7]) e = entropy(x) # ==> 0.80181855 g = tf.gradients(e, x)[0] # ==> array([1.30258512, 0.60943794, 0., -0.64332503], dtype=float32) Yay! No NaN. 

(注意:删除了dup交叉post 。)

一般食谱

使用内部tf.where来确保函数没有渐近线。 也就是说,改变inf生成函数的input,使得不能创buildinf。 然后使用第二个tf.where来始终select有效的代码path。 也就是说,像“正常”那样执行math条件,即“天真”的实现。

在Python代码中,配方是:

而不是这个:

 tf.where(x_ok, f(x), safe_f_x) 

做这个:

 safe_x = tf.where(x_ok, x, safe_x) tf.where(x_ok, f(safe_x), safe_f_x) 

假设你想计算:

 f(x) = { 1/x, x!=0 { 0, x=0 

一个幼稚的实现导致NaN在梯度上,即,

 def f(x): x_ok = tf.not_equal(x, 0.) safe_f_x = tf.zeros_like(x) return tf.where(x_ok, 1. / x, safe_f_x) 

它工作吗?

 x = tf.constant([-1., 0, 1]) tf.gradients(f(x), x)[0].eval() # ==> array([ -1., nan, -1.], dtype=float32) # ...bah! We have a NaN at the asymptote despite not having # an asymptote in the non-differentiated result. 

使用tf.where时避免NaN渐变的基本模式是两次调用tf.where 。 最里面的tf.where确保结果f(x)总是有限的。 最外面的tf.where确保select正确的结果。 对于运行的例子来说,这个技巧就是这样玩的:

 def safe_f(x): x_ok = tf.not_equal(x, 0.) safe_x = tf.where(x_ok, x, tf.ones_like(x)) safe_f_x = tf.zeros_like(x) return tf.where(x_ok, 1. / safe_x, safe_f_x) 

但它工作?

 x = tf.constant([-1., 0, 1]) tf.gradients(safe_f(x), x)[0].eval() # ==> array([-1., 0., -1.], dtype=float32) # ...yay! double-where trick worked. Notice that the gradient # is now a constant at the asymptote (as opposed to being NaN). 

您正在尝试使用标准公式计算交叉熵 。 当x=0 ,不但数值不确定,而且在数值上也不稳定。

最好使用tf.nn.softmax_cross_entropy_with_logits,或者如果你真的想使用手工制作的公式,在日志中将tf.clip_by_value零设置为非常小的数字。

下面是TensorFlow 1.1中二元(sigmoid)和分类(softmax)交叉熵损失的实现:

正如人们在二进制情况下所看到的,他们考虑了一些特殊情况来实现数值稳定性:

 # The logistic loss formula from above is # x - x * z + log(1 + exp(-x)) # For x < 0, a more numerically stable formula is # -x * z + log(1 + exp(x)) # Note that these two expressions can be combined into the following: # max(x, 0) - x * z + log(1 + exp(-abs(x))) # To allow computing gradients at zero, we define custom versions of max and # abs functions. zeros = array_ops.zeros_like(logits, dtype=logits.dtype) cond = (logits >= zeros) relu_logits = array_ops.where(cond, logits, zeros) neg_abs_logits = array_ops.where(cond, -logits, logits) return math_ops.add(relu_logits - logits * labels, math_ops.log1p(math_ops.exp(neg_abs_logits)), name=name)