环境:tensorflow 2.0.0
Tensorflow
中,继承Layer
实现自定义类,一般重写build()
方法,且build()
方法需接收input_shape
这一个参数。
1 | class customLayer(Layer): |
若该层只有一个输入,那么input_shape
就是该输入的shape
。
若该层有多个输入,那么input_shape
为每个输入shape
的list
。需要注意的是,这时候call()
函数中的多个输入需要以一个list
的形式传入。
下面展示正确用法与错误用法,注释中详细介绍了细节:
正确用法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22class myLayer(Layer):
def __init__(self, *args, **kwargs):
super(myLayer, self).__init__(*args, **kwargs)
def build(self, input_shape):
print(input_shape) # 输出结果为:[TensorShape([None, 5]), TensorShape([None, 3])]
shape1 = input_shape[0][1]
shape2 = input_shape[1][1]
self.w1 = self.add_weight(shape=(shape1, 1))
self.w2 = self.add_weight(shape=(shape2, 1))
def call(self, x): # 这里多输入需要将x以list的形式传入
return tf.add(tf.matmul(x[0], self.w1), tf.matmul(x[1], self.w2))
inputs1 = Input((5))
inputs2 = Input((3))
outputs = myLayer()([inputs1, inputs2])
model = Model(inputs=[inputs1, inputs2], outputs=outputs)
model.compile(loss=tf.keras.losses.binary_crossentropy)错误用法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22class myLayer(Layer):
def __init__(self, *args, **kwargs):
super(myLayer, self).__init__(*args, **kwargs)
def build(self, input_shape):
print(input_shape) # 输出结果为(None, 6),因为call()函数中分开传入了多个输入
shape1 = input_shape[0][1]
shape2 = input_shape[1][1]
self.w1 = self.add_weight(shape=(shape1, 1))
self.w2 = self.add_weight(shape=(shape2, 1))
def call(self, x1, x2): # 这里分别输入,得不到正确结果
return tf.add(tf.matmul(x1, self.w1), tf.matmul(x2, self.w2))
inputs1 = Input((6))
inputs2 = Input((3))
outputs = myLayer()(inputs1, inputs2)
model = Model(inputs=[inputs1, inputs2], outputs=outputs)
model.compile(loss=tf.keras.losses.binary_crossentropy)