Question
My goal is to load a base model (a Keras model) from an .hdf5 file and continue training it with federated learning (FL). Here is how I initialize the base model for FL:
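import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    # The Keras loader lives under tf.keras.models (tf.keras.load_model does not exist).
    model = tf.keras.models.load_model('path/to/model.hdf5')
    # db, loss and metrics are defined elsewhere in my code.
    return tff.learning.from_keras_model(model=model,
                                         dummy_batch=db,
                                         loss=loss,
                                         metrics=metrics)

trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()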
However, the resulting state.model weights seem to be randomly initialized and differ from those of my saved model. When I evaluate the model before any federated training, it performs like a randomly initialized model, with 50% accuracy. Here's how I evaluate the performance:
def evaluate(state):
    # Rebuild the architecture from the saved file, then overwrite its weights with the TFF state.
    keras_model = tf.keras.models.load_model('path/to/model.hdf5', compile=False)
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    keras_model.compile(loss=loss, metrics=metrics)
    return keras_model.evaluate(features, values)
How can I initialize a tff model with the saved model weights?
Answer 1:
Yes, I think it is expected that initialize reruns the initializers and returns the resulting values.
However, there is a way to do this with TFF. TFF is strongly typed and functional: if we can construct an argument with the correct values that matches the type expected by your federated averaging process above, things should "just work". So the goal is to construct an argument that satisfies these requirements.
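A toy, pure-Python illustration of the point above may help. It is a sketch under assumptions, not TFF code: ToyProcess, ServerState and ModelWeights are hypothetical stand-ins for an iterative process and its server state. Here initialize() draws fresh random weights on every call, much as rerunning Keras initializers would, while next() only checks that the state it receives has the expected structure, so a hand-built state holding pretrained values is accepted just like the initialized one.

```python
import random
from typing import List, NamedTuple


class ModelWeights(NamedTuple):
    # Hypothetical stand-in for the model weights stored in the server state.
    trainable: List[float]
    non_trainable: List[float]


class ServerState(NamedTuple):
    # Hypothetical stand-in for the value returned by trainer.initialize().
    model: ModelWeights
    round_num: int


class ToyProcess:
    """Toy analogue of an iterative process with initialize() and next()."""

    def __init__(self, num_trainable: int, num_non_trainable: int, seed: int = 0):
        self._num_trainable = num_trainable
        self._num_non_trainable = num_non_trainable
        self._rng = random.Random(seed)

    def initialize(self) -> ServerState:
        # Runs the "initializers" on every call: fresh random trainable values,
        # unrelated to any weights that were loaded from disk elsewhere.
        weights = ModelWeights(
            trainable=[self._rng.uniform(-1.0, 1.0) for _ in range(self._num_trainable)],
            non_trainable=[0.0] * self._num_non_trainable,
        )
        return ServerState(model=weights, round_num=0)

    def next(self, state: ServerState, client_delta: float) -> ServerState:
        # A pure function of its arguments: it only checks the structure of the
        # state, not where the values came from.
        if (len(state.model.trainable) != self._num_trainable
                or len(state.model.non_trainable) != self._num_non_trainable):
            raise TypeError("state does not match the expected structure")
        # Stand-in for one round of federated averaging: shift trainable weights.
        updated = [w + client_delta for w in state.model.trainable]
        return ServerState(
            model=state.model._replace(trainable=updated),
            round_num=state.round_num + 1,
        )


# Usage: swap pretrained values into the initialized state, then keep training.
process = ToyProcess(num_trainable=2, num_non_trainable=1)
state = process.initialize()
pretrained = ModelWeights(trainable=[0.5, -0.25], non_trainable=[3.0])
state = process.next(state._replace(model=pretrained), client_delta=0.25)
print(state.model.trainable, state.round_num)  # prints [0.75, 0.0] 1
```

The design choice being illustrated is that the process never "remembers" where its weights came from; whatever correctly shaped state you pass in is what gets trained.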
You can look at the load implementation of FileCheckpointManager for some inspiration, but I think your Keras case is simpler.
Assuming you have state as above and model is your Keras model, there is a shortcut for unpacking and repacking everything, as shown in one of TFF's tutorials: tff.learning.state_with_new_model_weights. If you have state and model as above (and TF is running in eager mode), the following should work for you:
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in model.non_trainable_weights
    ])
This should reassign your model's weights to the appropriate elements of the state object.
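The idea behind state_with_new_model_weights, replacing only the model weights in the state while keeping its structure, can be sketched in plain Python. This is an analogue under stated assumptions, not TFF's implementation: ServerState, ModelWeights, shape_of and with_new_model_weights are hypothetical names, and nested lists stand in for the NumPy arrays produced by v.numpy(). The helper refuses weights whose count or shapes differ from those already in the state, which mirrors the requirement that the new value match the type the process expects.

```python
from typing import Any, List, NamedTuple, Sequence, Tuple


class ModelWeights(NamedTuple):
    # Hypothetical container: one entry per weight "tensor" (nested lists here).
    trainable: List[Any]
    non_trainable: List[Any]


class ServerState(NamedTuple):
    # Hypothetical server state; only the model field is replaced below.
    model: ModelWeights
    round_num: int


def shape_of(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    if isinstance(value, list):
        return (len(value),) + (shape_of(value[0]) if value else ())
    return ()


def with_new_model_weights(state: ServerState,
                           trainable: Sequence[Any],
                           non_trainable: Sequence[Any]) -> ServerState:
    """Return a new state whose model weights are replaced, after checking
    that the number and shapes of the tensors match the existing ones."""
    pairs = (("trainable", state.model.trainable, trainable),
             ("non_trainable", state.model.non_trainable, non_trainable))
    for name, old, new in pairs:
        if len(old) != len(new):
            raise ValueError(f"{name}: expected {len(old)} tensors, got {len(new)}")
        for i, (o, n) in enumerate(zip(old, new)):
            if shape_of(o) != shape_of(n):
                raise ValueError(
                    f"{name}[{i}]: expected shape {shape_of(o)}, got {shape_of(n)}")
    # Everything except the model weights (e.g. round_num) is carried over.
    return state._replace(model=ModelWeights(list(trainable), list(non_trainable)))


# Usage: a freshly initialized state for a 2x2 kernel plus a bias of length 2.
initial = ServerState(
    model=ModelWeights(trainable=[[[0.1, 0.2], [0.3, 0.4]], [0.0, 0.0]],
                       non_trainable=[]),
    round_num=0,
)
pretrained_trainable = [[[1.0, 2.0], [3.0, 4.0]], [0.5, -0.5]]
state = with_new_model_weights(initial, pretrained_trainable, [])
```

The sketch returns a new state rather than mutating the old one, in the functional spirit described above, and fails loudly if the pretrained weights come from a model with a different architecture.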
Source: https://stackoverflow.com/questions/61786305/tff-loading-a-pre-trained-keras-model