本文适用于以下场景(假定
model
中含有复杂的custom loss function
):
load
已训练完成的model
,进行infer
- 模型滚动更新:加载旧模型,在旧模型的基础上进行增量训练得到新模型
运行环境
Tensorflow2.0.0
定义custom loss function
这里定义的是
pairwise
场景下binary crossentropy
的loss function
:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37def pairwise_binary_crossentropy(query):
'''
pairwise_binary_crossentropy, each pair example uses binary crossentropy as loss.
Args:
query: query id
Returns:
'''
def loss(y_true, y_pred):
pair_mask = tf.equal(query, tf.transpose(query))
pair_mask = tf.cast(pair_mask, tf.float32)
#根据pair_mask创建对角矩阵。
pair_mask_diag = tf.linalg.band_part(tf.ones_like(pair_mask), -1, 0)
pair_mask_diag = tf.linalg.band_part(pair_mask_diag, 0, 0)
pair_mask_diag_reverse = 1 - pair_mask_diag
pair_mask = pair_mask * pair_mask_diag_reverse
si_minus_sj = y_pred - tf.transpose(y_pred)
yi_minus_yj = y_true - tf.transpose(y_true)
yi_minus_yj = tf.maximum(tf.minimum(1., yi_minus_yj), -1.)
yi_minus_yj = 0.5 * (1 + yi_minus_yj)
logloss = tf.nn.sigmoid_cross_entropy_with_logits(labels=yi_minus_yj, logits=si_minus_sj)
num_pairs = tf.reduce_sum(pair_mask)
loss = pair_mask * logloss
loss = tf.reduce_sum(loss)
res = loss / (num_pairs + 0.00001)
return res
return loss可以看出,这里不仅仅是
loss_function(y_true, y_pred)
形式了,而是最外层还有query
这个输入。
解决方案
对于利用
tf.keras.models.load_model()
加载模型,然后直接输入测试数据进行infer
的场景此时若不调用
model.compile()
进行编译的话,是会报错的,无法得到prediction
。而将custom loss function
传入model.compile(loss=)
中loss
参数的话,报的错更是千奇百怪。我们只需要给loss
传入任意一个 keras 提供的loss function
,如binary_crossentropy
即可。如model.compile(loss=tensorflow.keras.losses.binary_crossentropy)
。因为 infer 时就是 network 的前向传播计算过程,与 loss 的具体形式无关,只与网络结构与权重有关。所以,即使随机指定 loss 形式,只要
model.compile()
通过即可。经实测,不同的loss function
形式确实不会影响最终的prediction
结果。不放心的小伙伴可以自己尝试。对于模型滚动更新场景
按常理说,模型滚动更新的步骤应该是这样的:1.加载模型,得到基准模型对象
model
;2.读取训练数据,利用基准模型对象,调用model.fit()
接口,进行模型增量训练。理想很美好,现实很骨感。上述做法中,必须先调用
model.compile()
指定loss
以及optimizer
后,才可以正常的调用model.fit()
,执行后续的训练流程。好,现在又绕回来了。只要调用
model.compile()
,就逃不过model.compile(loss=xxx)
中custom loss function
如何正确传入的问题。这里与infer场景不同,是无法随意传入loss function
让其compile通过即可,而是一定要指定自定义的custom loss function
,因为模型滚动更新时训练需要利用自定义的custom loss function
训练更新权重。「Restore 模型后得到基准模型对象,并在此基础上调用fit接口进行训练」这条路已经走到死胡同了,那我们不妨退出来再看看有什么可以“曲线救国”的方法。
明确一点:模型更新的实质是什么?模型更新的实质是参数的更新。那么模型滚动更新的实质是什么?其实质是在现有模型的参数基础上增量更新。
所以,我们只要从0构建与基准模型拥有完全相同结构的网络,并将基准模型的权重赋值给新构建的网络,再调用fit接口就可以实现模型增量更新的功能。