Question
I'm trying to create an activation function in Keras that can take in a parameter beta
like so:
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation
class Swish(Activation):
    # NOTE(review): subclassing keras.layers.Activation only wraps a plain
    # function; the extra attributes set below are NOT serialized by
    # model.to_json(), which is the asker's core problem.
    def __init__(self, activation, beta, **kwargs):
        super(Swish, self).__init__(activation, **kwargs)
        self.__name__ = 'swish'
        self.beta = beta

# BUG (the question's subject): `beta` is a free variable here -- it is not
# the `self.beta` stored on the Swish instance, so this raises NameError
# unless a global `beta` happens to exist.
def swish(x):
    return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})
It runs fine without the beta parameter, but how can I include the parameter in the activation definition? I also want this value to be saved when I call model.to_json(), as is done for the ELU activation.
Update: I wrote the following code based on @today's answer:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
    """Swish activation as a proper Keras layer: f(x) = x * sigmoid(beta * x).

    Implementing it as a Layer (rather than a wrapped function) lets
    `get_config()` carry `beta` into the serialized model config.
    """

    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        # Cast once so beta matches the backend float type.
        self.beta = K.cast_to_floatx(beta)
        self.__name__ = 'swish'

    def call(self, inputs):
        # Elementwise x * sigmoid(beta * x).
        return inputs * K.sigmoid(self.beta * inputs)

    def get_config(self):
        # Merge beta into the base layer config so serialization keeps it.
        merged = dict(super(Swish, self).get_config())
        merged['beta'] = float(self.beta)
        return merged

    def compute_output_shape(self, input_shape):
        # Activation is elementwise; shape passes through unchanged.
        return input_shape
from keras.utils.generic_utils import get_custom_objects

# NOTE(review): this maps the key 'swish' to a fixed *instance* with beta=1.0.
# Any beta stored in the saved file is therefore ignored on load, and if the
# saved model used swish as a plain activation *function* (not a layer),
# get_config() is never invoked, so beta never reaches the json -- which
# matches the reported symptom. The usual pattern is to register the class
# itself (e.g. {'Swish': Swish}) and use the layer in the model.
get_custom_objects().update({'swish': Swish(beta=1.)})

gnn = keras.models.load_model("Model.h5")  # presumably `import keras` appears earlier -- verify
arch = gnn.to_json()  # beta is included only if the model actually contains Swish layers
with open(directory + 'architecture.json', 'w') as arch_file:  # `directory` defined elsewhere in the session
    arch_file.write(arch)
However, it does not currently save the beta
value in the .json file. How can I make it save the value?
Answer 1:
Since you want to save the parameters of activation function when serializing the model, I think it is better to define the activation function as a layer like the advanced activations which have been defined in Keras. You can do it like this:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
    """Parameterized Swish activation layer: f(x) = sigmoid(beta * x) * x.

    Because `beta` is reported by `get_config()`, it survives round-trips
    through `model.to_json()` / `model.save()` / `load_model()`.
    """

    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        # Store beta in the backend's float type.
        self.beta = K.cast_to_floatx(beta)

    def call(self, inputs):
        # Elementwise sigmoid(beta * x) * x.
        gate = K.sigmoid(self.beta * inputs)
        return gate * inputs

    def get_config(self):
        # Start from the base Layer config and attach beta so it serializes.
        cfg = dict(super(Swish, self).get_config())
        cfg['beta'] = float(self.beta)
        return cfg

    def compute_output_shape(self, input_shape):
        # Elementwise op: output shape equals input shape.
        return input_shape
Then you can use it the same way you use a Keras layer:
# ... build the model, then add Swish as an ordinary layer; beta=0.3 is
# recorded by get_config() and so appears in to_json()/save() output.
model.add(Swish(beta=0.3))
Since the get_config() method has been implemented in its definition, the parameter beta will be saved when using methods like to_json() or save().
Source: https://stackoverflow.com/questions/53050448/custom-activation-with-parameter