TianduAI

Action speak louder than words

0%

Tensorflow自定义Layer多输入input_shape

环境:tensorflow 2.0.0

Tensorflow中,继承Layer实现自定义类,一般重写build()方法,且build()方法需接收input_shape这一个参数。

1
2
3
class customLayer(Layer):
def build(self, input_shape):
xxx

若该层只有一个输入,那么input_shape就是该输入的shape

若该层有多个输入,那么input_shape为每个输入shapelist。需要注意的是,这时候call()函数中的多个输入需要以一个list的形式传入

下面展示正确用法与错误用法,注释中详细介绍了细节:

  • 正确用法

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    class myLayer(Layer):
    def __init__(self, *args, **kwargs):
    super(myLayer, self).__init__(*args, **kwargs)

    def build(self, input_shape):
    print(input_shape) # 输出结果为:[TensorShape([None, 5]), TensorShape([None, 3])]
    shape1 = input_shape[0][1]
    shape2 = input_shape[1][1]

    self.w1 = self.add_weight(shape=(shape1, 1))
    self.w2 = self.add_weight(shape=(shape2, 1))

    def call(self, x): # 这里多输入需要将x以list的形式传入
    return tf.add(tf.matmul(x[0], self.w1), tf.matmul(x[1], self.w2))

    inputs1 = Input((5))
    inputs2 = Input((3))

    outputs = myLayer()([inputs1, inputs2])

    model = Model(inputs=[inputs1, inputs2], outputs=outputs)
    model.compile(loss=tf.keras.losses.binary_crossentropy)
  • 错误用法

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    class myLayer(Layer):
    def __init__(self, *args, **kwargs):
    super(myLayer, self).__init__(*args, **kwargs)

    def build(self, input_shape):
    print(input_shape) # 输出结果为(None, 6),因为call()函数中分开传入了多个输入
    shape1 = input_shape[0][1]
    shape2 = input_shape[1][1]

    self.w1 = self.add_weight(shape=(shape1, 1))
    self.w2 = self.add_weight(shape=(shape2, 1))

    def call(self, x1, x2): # 这里分别输入,得不到正确结果
    return tf.add(tf.matmul(x1, self.w1), tf.matmul(x2, self.w2))

    inputs1 = Input((6))
    inputs2 = Input((3))

    outputs = myLayer()(inputs1, inputs2)

    model = Model(inputs=[inputs1, inputs2], outputs=outputs)
    model.compile(loss=tf.keras.losses.binary_crossentropy)