I cannot find a way to set the initial weights of the neural network. Could someone tell me how, please? I am using the Python package sklearn.neural_network.MLPClassifier.
Solution: A working solution is to inherit from MLPClassifier and override the _init_coef method. In _init_coef, write the code that sets the initial weights. Then use the new class "MLPClassifierOverride", as in the example below, instead of "MLPClassifier".
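import numpy as np
from sklearn.neural_network import MLPClassifier

# new class
class MLPClassifierOverride(MLPClassifier):
    # Overriding the private _init_coef method. Newer scikit-learn releases
    # also pass a dtype as a third argument, hence the optional parameter.
    def _init_coef(self, fan_in, fan_out, dtype=None):
        # Bounds used by the stock implementation, kept in case you want them
        if self.activation == 'logistic':
            init_bound = np.sqrt(2. / (fan_in + fan_out))
        elif self.activation in ('identity', 'tanh', 'relu'):
            init_bound = np.sqrt(6. / (fan_in + fan_out))
        else:
            raise ValueError("Unknown activation function %s" %
                             self.activation)
        coef_init = ...  # place your initial values here, shape (fan_in, fan_out)
        intercept_init = ...  # place your initial values here, shape (fan_out,)
        return coef_init, intercept_init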
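To make the override concrete, here is a minimal sketch with the placeholders filled in. The class name ConstantInitMLP, the constant 0.01, and the toy XOR data are illustrative assumptions, not part of the original answer. A constant start makes every hidden unit identical, so it is only useful for seeing that the override takes effect. Since _init_coef is a private method, its signature may differ between scikit-learn versions; the *args below is there to absorb the extra dtype argument that newer releases pass.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# A handful of iterations will not converge; silence that warning.
warnings.filterwarnings("ignore", category=ConvergenceWarning)


class ConstantInitMLP(MLPClassifier):
    """Hypothetical subclass: every weight starts at 0.01, every bias at 0."""

    def _init_coef(self, fan_in, fan_out, *args):
        # *args absorbs the dtype argument passed by newer scikit-learn.
        coef_init = np.full((fan_in, fan_out), 0.01)
        intercept_init = np.zeros(fan_out)
        return coef_init, intercept_init


# Toy data (assumption): the XOR problem with two input features.
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = np.array([0, 1, 1, 0])

clf = ConstantInitMLP(hidden_layer_sizes=(3,), solver="sgd", max_iter=5)
clf.fit(X, y)

# One weight matrix per pair of adjacent layers.
shapes = [w.shape for w in clf.coefs_]
print(shapes)
```

In real use you would return arrays of the same shapes drawn from whatever distribution you want to try.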
multilayer_perceptron.py initializes the weights based on the nonlinear function used for the hidden layers. If you want to try a different initialization, you can take a look at the function _init_coef in that file and modify it as you desire.
The docs show you the attributes in use.
Attributes:
...
coefs_ : list, length n_layers - 1. The ith element in the list represents the weight matrix corresponding to layer i.
intercepts_ : list, length n_layers - 1. The ith element in the list represents the bias vector corresponding to layer i + 1.
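As a quick check of what those attributes look like, the sketch below fits a small network and prints the shapes. The toy data and the layer sizes (5, 3) are assumptions chosen only for illustration.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Ten iterations will not converge; silence that warning.
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Toy data (assumption): 4 samples, 2 features, 2 classes.
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = np.array([0, 1, 1, 0])

clf = MLPClassifier(hidden_layer_sizes=(5, 3), max_iter=10, random_state=0)
clf.fit(X, y)

# Layers: 2 inputs -> 5 hidden -> 3 hidden -> 1 output (binary problem),
# so both lists have n_layers - 1 = 3 entries.
coef_shapes = [w.shape for w in clf.coefs_]
intercept_shapes = [b.shape for b in clf.intercepts_]
print(coef_shapes)       # [(2, 5), (5, 3), (3, 1)]
print(intercept_shapes)  # [(5,), (3,), (1,)]
```

Any arrays you assign to coefs_ and intercepts_ need to match these shapes.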
Just build your classifier clf = MLPClassifier(solver="sgd", warm_start=True) and set coefs_ and intercepts_ before calling clf.fit(). The warm_start=True part matters, as the check below shows.
The only remaining question is: does sklearn overwrite your inits?
The code looks like:
if not hasattr(self, 'coefs_') or (not self.warm_start and not
                                   incremental):
    # First time training the model
    self._initialize(y, layer_units)
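Reading this condition, your given coefs_ are kept only if warm_start=True (or when training incrementally through partial_fit). With the default warm_start=False, fit() calls _initialize and replaces them, and the same goes for intercepts_. Note also that _initialize sets up more than the weights (for example the iteration counter and the loss history), so it is probably more robust to let one fit() run first and then overwrite coefs_ and intercepts_ before fitting again with warm_start=True.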
The packing and unpacking functions further indicate that this should be possible. They flatten the weights and biases into a single parameter vector and back, which the L-BFGS solver needs; they are not there for pickling.
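Putting the warm-start reasoning together, a minimal sketch could look like the following. It is an illustration under assumptions, not a guaranteed recipe: the toy data, the normal(0, 0.1) start values, and the one-iteration priming fit are all choices made here, and the behaviour relies on the condition quoted above, which may change between scikit-learn versions.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Toy data (assumption): the XOR problem.
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
y = np.array([0, 1, 1, 0])

clf = MLPClassifier(hidden_layer_sizes=(3,), solver="sgd",
                    warm_start=True, max_iter=1, random_state=0)

# Priming fit: creates coefs_, intercepts_ and the rest of the fitted state.
clf.fit(X, y)

# Overwrite the parameters with custom values of matching shapes.
rng = np.random.default_rng(0)
clf.coefs_ = [rng.normal(0.0, 0.1, size=w.shape) for w in clf.coefs_]
clf.intercepts_ = [np.zeros_like(b) for b in clf.intercepts_]

# Keep a copy: the solver updates the parameter arrays in place.
custom_start = [w.copy() for w in clf.coefs_]

# With warm_start=True this fit continues from the values set above
# instead of calling _initialize again.
clf.set_params(max_iter=50)
clf.fit(X, y)
predictions = clf.predict(X)
```

Comparing clf.coefs_ with custom_start afterwards shows how far training moved the weights from your starting point.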