I am trying to write my own keras layer. In this layer, I want to use some other keras layers. Is there any way to do something like this:
class MyDenseLayer(tf.
It's much more comfortable and concise to put existing layers in the tf.keras.models.Model class. If you define non-custom layers, such as layers.Conv2D, the parameters of those layers are not trainable by default.
class MyDenseLayer(tf.keras.Model):
    """Subclassed Keras model that wraps a single ``Dense`` layer."""

    def __init__(self, num_outputs):
        """Remember the output width and create the inner ``Dense`` layer.

        Args:
            num_outputs: Unit count handed to ``tf.keras.layers.Dense``; also
                used as the last dimension in ``compute_output_shape``.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.fc = tf.keras.layers.Dense(num_outputs)

    def call(self, input):
        """Apply the inner ``Dense`` layer to ``input`` and return the result."""
        return self.fc(input)

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension set to ``num_outputs``."""
        dims = tf.TensorShape(input_shape).as_list()
        # Only the trailing axis changes; leading dimensions pass through.
        dims[-1] = self.num_outputs
        return tf.TensorShape(dims)
# Instantiate the wrapper with ten output units.
layer = MyDenseLayer(num_outputs=10)
Check this tutorial: https://www.tensorflow.org/guide/keras#model_subclassing
This works for me and is clean, concise, and readable.
import tensorflow as tf
class MyDense(tf.keras.layers.Layer):
    """Custom layer that delegates to an inner two-unit ReLU ``Dense`` layer."""

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: Keyword arguments forwarded unchanged to
                ``tf.keras.layers.Layer.__init__``.
        """
        # Unpack the keyword arguments. Passing the dict itself positionally
        # would bind it to the base constructor's first positional parameter
        # (``trainable`` in tf.keras 2.x) instead of forwarding its entries.
        super(MyDense, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(2, tf.keras.activations.relu)

    def call(self, inputs, training=None):
        """Apply the inner ``Dense`` layer to ``inputs``.

        ``training`` is accepted but not used by this layer.
        """
        return self.dense(inputs)
# Build, compile and summarize a one-layer functional model around MyDense.
# ``shape`` is given as a tuple (batch dimension excluded), as the
# ``tf.keras.Input`` API documents, rather than as a bare integer.
inputs = tf.keras.Input(shape=(10,))
outputs = MyDense(trainable=True)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()
Output:
Model: "test"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 10)] 0
_________________________________________________________________
my_dense (MyDense) (None, 2) 22
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
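Note that when `kwargs` is passed positionally to `super().__init__`, trainable=True
is needed, because the dict itself ends up in the base layer's `trainable` parameter. I have posted a question about it here.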
If you look at the documentation for how to add custom layers, they recommend that you use the .add_weight(...)
method. This method internally places all weights in self._trainable_weights
. So to do what you want, you must first define the Keras layers you want to use, build them, copy the weights, and then build your own layer. If I update your code, it should be something like:
class mylayer(tf.keras.layers.Layer):
    """Custom layer that builds an inner ``Dense`` layer and exposes its weights."""

    def __init__(self, num_outputs, num_outputs2=None):
        """Create the layer.

        Args:
            num_outputs: Unit count for the inner ``Dense`` layer.
            num_outputs2: Accepted for backward compatibility and not used.
                It defaults to ``None`` so that ``mylayer(10)`` is a valid
                call.
        """
        # Initialise the base layer before assigning any attributes.
        super(mylayer, self).__init__()
        self.num_outputs = num_outputs

    def build(self, input_shape):
        """Create and build the inner ``Dense`` layer for ``input_shape``."""
        self.fc = tf.keras.layers.Dense(self.num_outputs)
        self.fc.build(input_shape)
        # NOTE(review): this writes to a private Keras attribute so the inner
        # layer's weights are reported as this layer's trainable weights.
        # Confirm it is still required on the Keras version in use.
        self._trainable_weights = self.fc.trainable_weights
        super(mylayer, self).build(input_shape)

    def call(self, input):
        """Apply the inner ``Dense`` layer to ``input``."""
        return self.fc(input)


# Wire the layer into a functional model and print its summary.
# ``inputs``/``outputs`` are used so the builtin ``input`` is not shadowed.
layer = mylayer(10)
inputs = tf.keras.layers.Input(shape=(16,))
outputs = layer(inputs)
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.summary()
You should then get what you want
In the TF2 custom layer Guide, they "recommend creating such sublayers in the __init__
method (since the sublayers will typically have a build
method, they will be built when the outer layer gets built)." So just moving the creation of self.fc
into __init__
will give what you want.
class MyDenseLayer(tf.keras.layers.Layer):
    """Custom layer whose inner ``Dense`` sublayer is created in ``__init__``."""

    def __init__(self, num_outputs):
        """Record ``num_outputs`` and create the ``Dense`` sublayer with it."""
        super().__init__()
        self.num_outputs = num_outputs
        self.fc = tf.keras.layers.Dense(num_outputs)

    def build(self, input_shape):
        """Mark this layer as built; no weights are created here."""
        self.built = True

    def call(self, input):
        """Return the output of the ``Dense`` sublayer applied to ``input``."""
        return self.fc(input)
# Build a functional model around MyDenseLayer and print its summary.
# ``inputs``/``outputs`` are used so the builtin ``input`` is not shadowed.
inputs = tf.keras.layers.Input(shape=(16,))
outputs = MyDenseLayer(10)(inputs)
model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.summary()
Output:
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 16)] 0
_________________________________________________________________
my_dense_layer_2 (MyDenseLay (None, 10) 170
=================================================================
Total params: 170
Trainable params: 170
Non-trainable params: 0