环境:tensorflow 1.14.0
在模型训练的过程中,经常会有实时观察中间tensor
值的需求。Python
自带的print
函数只能打印出tensor
或者op
的名字、属性信息等等,所以需另辟蹊径。
实际上,在模型训练时,打印出某些tensor
的中间值,有两种实现方法:
- 运行
sess.run(tensor)
,可以得到tensor
的值 - 采用
tf.Print(tensor)
函数,直接打印出tensor
的当前值
第一种方法的典型场景是间隔性的获取结果类型tensor
值。比如,在训练过程中每迭代100次打印出当前的loss
值以便观察。
当然,如果tensor
的值不依赖于placeholder
的计算,比如某一层的权重tensor-W
,调用sess.run(tensor)
时是不需要指定feed_dict
的。这样也可以实现在训练过程中打印出tensor
值的目标。
第二种方法更加具有普适性,可以打印出任意tensor
在当前迭代时的值。不过,在使用tf.Print()
函数时,需要用静态图的视角去考虑。
tf.Print()
其实是静态图中的一个op
节点。如果在运行一个op时,数据流没有流经该tf.Print op
节点,则tf.Print()
不会生效。如果加入了tf.Print
却没有按预期打印出想要的结果,建议仔细检查下静态图逻辑。
数据流入tf.Print
节点后,会原样返回。所以tf.Print
可以看做恒等输出op
,只是打印出了流入数据的值而已。
1 | tf.Print( |
input_
为输入的tensor
,会原样返回;data
为需要打印的tensor
;message
为日志输出时的前缀信息;