TianduAI

Action speak louder than words

0%

tf.Print()用法

环境:tensorflow 1.14.0

在模型训练的过程中,经常会有实时观察中间tensor值的需求。Python自带的print函数只能打印出tensor或者op的名字、属性信息等等,所以需另辟蹊径。

实际上,在模型训练时,打印出某些tensor的中间值,有两种实现方法:

  • 运行sess.run(tensor),可以得到tensor的值
  • 采用tf.Print(tensor)函数,直接打印出tensor的当前值

第一种方法的典型场景是间隔性的获取结果类型tensor。比如,在训练过程中每迭代100次打印出当前的loss值以便观察。

当然,如果tensor的值不依赖于placeholder的计算,比如某一层的权重tensor-W,调用sess.run(tensor)时是不需要指定feed_dict的。这样也可以实现在训练过程中打印出tensor值的目标。

第二种方法更加具有普适性,可以打印出任意tensor在当前迭代时的值。不过,在使用tf.Print()函数时,需要用静态图的视角去考虑。

tf.Print()其实是静态图中的一个op节点。如果在运行一个op时,数据流没有流经该tf.Print op节点,则tf.Print()不会生效。如果加入了tf.Print却没有按预期打印出想要的结果,建议仔细检查下静态图逻辑。

数据流入tf.Print节点后,会原样返回。所以tf.Print可以看做恒等输出op,只是打印出了流入数据的值而已。

1
2
3
4
5
6
7
8
tf.Print(
input_,
data,
message=None,
first_n=None,
summarize=None,
name=None
)

input_为输入的tensor,会原样返回;
data为需要打印的tensor
message为日志输出时的前缀信息;