TianduAI

Action speak louder than words

0%

tf.keras如何解决load model时custom loss function的各种报错

本文适用于以下场景(假定model中含有复杂的custom loss function):

  • load已训练完成的model,进行infer
  • 模型滚动更新:加载旧模型,在旧模型的基础上进行增量训练得到新模型

运行环境

  • Tensorflow2.0.0

定义custom loss function

  • 这里定义的是pairwise场景下binary crossentropyloss function

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    def pairwise_binary_crossentropy(query):
    '''
    pairwise_binary_crossentropy, each pair example uses binary crossentropy as loss.
    Args:
    query: query id

    Returns:

    '''
    def loss(y_true, y_pred):
    pair_mask = tf.equal(query, tf.transpose(query))
    pair_mask = tf.cast(pair_mask, tf.float32)

    #根据pair_mask创建对角矩阵。
    pair_mask_diag = tf.linalg.band_part(tf.ones_like(pair_mask), -1, 0)
    pair_mask_diag = tf.linalg.band_part(pair_mask_diag, 0, 0)
    pair_mask_diag_reverse = 1 - pair_mask_diag

    pair_mask = pair_mask * pair_mask_diag_reverse

    si_minus_sj = y_pred - tf.transpose(y_pred)

    yi_minus_yj = y_true - tf.transpose(y_true)
    yi_minus_yj = tf.maximum(tf.minimum(1., yi_minus_yj), -1.)

    yi_minus_yj = 0.5 * (1 + yi_minus_yj)

    logloss = tf.nn.sigmoid_cross_entropy_with_logits(labels=yi_minus_yj, logits=si_minus_sj)

    num_pairs = tf.reduce_sum(pair_mask)

    loss = pair_mask * logloss
    loss = tf.reduce_sum(loss)
    res = loss / (num_pairs + 0.00001)

    return res
    return loss

    可以看出,这里不仅仅是loss_function(y_true, y_pred)形式了,而是最外层还有query这个输入。

解决方案

  • 对于利用tf.keras.models.load_model()加载模型,然后直接输入测试数据进行infer的场景

    此时若不调用model.compile()进行编译的话,是会报错的,无法得到prediction。而将custom loss function传入model.compile(loss=)loss参数的话,报的错更是千奇百怪。我们只需要给loss传入任意一个 keras 提供的 loss function,如 binary_crossentropy 即可。如model.compile(loss=tensorflow.keras.losses.binary_crossentropy)

    因为 infer 时就是 network 的前向传播计算过程,与 loss 的具体形式无关,只与网络结构与权重有关。所以,即使随机指定 loss 形式,只要model.compile()通过即可。经实测,不同的loss function 形式确实不会影响最终的prediction结果。不放心的小伙伴可以自己尝试。

  • 对于模型滚动更新场景

    按常理说,模型滚动更新的步骤应该是这样的:1.加载模型,得到基准模型对象model;2.读取训练数据,利用基准模型对象,调用model.fit()接口,进行模型增量训练。

    理想很美好,现实很骨感。上述做法中,必须先调用model.compile()指定loss以及optimizer后,才可以正常的调用model.fit(),执行后续的训练流程。

    好,现在又绕回来了。只要调用model.compile(),就逃不过model.compile(loss=xxx)custom loss function如何正确传入的问题。这里与infer场景不同,是无法随意传入loss function让其compile通过即可,而是一定要指定自定义的custom loss function,因为模型滚动更新时训练需要利用自定义的custom loss function训练更新权重。

    「Restore 模型后得到基准模型对象,并在此基础上调用fit接口进行训练」这条路已经走到死胡同了,那我们不妨退出来再看看有什么可以“曲线救国”的方法。

    明确一点:模型更新的实质是什么?模型更新的实质是参数的更新。那么模型滚动更新的实质是什么?其实质是在现有模型的参数基础上增量更新

    所以,我们只要从0构建与基准模型拥有完全相同结构的网络,并将基准模型的权重赋值给新构建的网络,再调用fit接口就可以实现模型增量更新的功能。