TFF loading a pre-trained Keras model

Submitted by 独自空忆成欢 on 2020-07-09 14:26:46

Question


My goal is to load a base model from a .hdf5 file (it's a Keras model), and continue to train it with federated learning. Here is how I initialize the base model for FL:

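import tensorflow as tf
import tensorflow_federated as tff

# `db` (a sample batch), `loss` and `metrics` are defined elsewhere.
def model_fn():
    model = tf.keras.models.load_model('path/to/model.hdf5')
    return tff.learning.from_keras_model(keras_model=model,
                                         dummy_batch=db,
                                         loss=loss,
                                         metrics=metrics)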

trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()

However, the resulting state.model weights appear to be randomly initialized rather than equal to the weights of my saved model. Even before any federated training, the model performs like a randomly initialized one (50% accuracy). Here's how I evaluate the performance:

def evaluate(state):
    keras_model = tf.keras.models.load_model('path/to/model.hdf5', compile=False)
    # Overwrite the loaded weights with the ones held in the TFF server state.
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    keras_model.compile(loss=loss, metrics=metrics)
    return keras_model.evaluate(features, values)

How can I initialize a tff model with the saved model weights?


Answer 1:


Yes, I think this is expected: initialize reruns the variable initializers and returns the resulting values, so the weights loaded inside model_fn are not carried over into state.

However, there is a way to do this with TFF. TFF is strongly typed and functional: if we can construct an argument with the correct values that matches the type expected by your federated averaging process above, things should "just work". So the goal here is to construct an argument satisfying these requirements.
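The "construct an argument of the right type" idea can be illustrated without TFF at all. The sketch below is a plain-Python analogy, not TFF code: the hypothetical ModelWeights and ServerState classes stand in for the structure returned by trainer.initialize(), and the "type" of a weight list is approximated as the shape and dtype of each array, in order. Replacing the weights produces a new state value instead of mutating the old one, and the replacement is refused if its type does not match.

```python
import dataclasses
from typing import Sequence, Tuple

import numpy as np


@dataclasses.dataclass(frozen=True)
class ModelWeights:
    # Hypothetical stand-in for the model-weights part of a TFF server state.
    trainable: Tuple[np.ndarray, ...]
    non_trainable: Tuple[np.ndarray, ...]


@dataclasses.dataclass(frozen=True)
class ServerState:
    # Hypothetical stand-in for the value returned by trainer.initialize().
    model: ModelWeights
    round_num: int


def signature(arrays: Sequence[np.ndarray]):
    """Crude 'type' of a weight list: the shape and dtype of each array, in order."""
    return tuple((a.shape, a.dtype) for a in arrays)


def with_new_weights(state: ServerState, trainable, non_trainable) -> ServerState:
    """Return a new state holding the given weights; refuse values of the wrong type."""
    new_model = ModelWeights(
        trainable=tuple(np.asarray(a) for a in trainable),
        non_trainable=tuple(np.asarray(a) for a in non_trainable))
    old_sig = (signature(state.model.trainable), signature(state.model.non_trainable))
    new_sig = (signature(new_model.trainable), signature(new_model.non_trainable))
    if new_sig != old_sig:
        raise TypeError('new weights do not match the type of the existing state')
    # Nothing is mutated: `state` keeps its old weights and a new value is returned.
    return dataclasses.replace(state, model=new_model)


# Usage: a "freshly initialized" state, then the same structure with new values.
rng = np.random.default_rng(0)
state = ServerState(
    model=ModelWeights(
        trainable=(rng.normal(size=(3, 2)).astype(np.float32),
                   np.zeros(2, dtype=np.float32)),
        non_trainable=()),
    round_num=0)
pretrained = [np.ones((3, 2), dtype=np.float32), np.full(2, 0.5, dtype=np.float32)]
new_state = with_new_weights(state, pretrained, non_trainable=[])
```

Treating the state as an immutable value mirrors the functional style described above: the "initialized" state is just one value of the expected type, and any other value of the same type can be passed in its place.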

You can look to the FileCheckpointManager's load implementation for a little inspiration here, but I think you are in a simpler case with Keras.

Assuming you have state as above and model is your Keras model, there is a shortcut that avoids unpacking and repacking everything by hand, as shown in one of TFF's tutorials: tff.learning.state_with_new_model_weights. If you have state and model as above (and TF is in eager mode, so that .numpy() is available), the following should work for you:

# Copy the pre-trained Keras weights into the TFF server state (eager mode).
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in model.non_trainable_weights
    ])

This should reassign your model's weights to the appropriate elements of the state object.
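Before running any federated rounds, it may be worth checking directly that the state now carries the pre-trained values, rather than relying on accuracy alone. The weights_match helper below is a hypothetical NumPy sketch, not part of TFF. The commented usage assumes, as in the TFF releases of that time, that state.model.trainable and state.model.non_trainable are sequences of arrays in the same order as the Keras model's weight lists.

```python
import numpy as np


def weights_match(expected, actual, atol=1e-6):
    """Return True if two sequences of arrays agree in length, shapes and values."""
    expected = [np.asarray(e) for e in expected]
    actual = [np.asarray(a) for a in actual]
    if len(expected) != len(actual):
        return False
    # `and` short-circuits, so np.allclose only sees arrays of identical shape.
    return all(e.shape == a.shape and np.allclose(e, a, atol=atol)
               for e, a in zip(expected, actual))


# Hypothetical check, assuming `model` and the updated `state` from the answer,
# eager mode, and that state.model.trainable / state.model.non_trainable list
# the arrays in the same order as the Keras model:
#
#   assert weights_match([v.numpy() for v in model.trainable_weights],
#                        state.model.trainable)
#   assert weights_match([v.numpy() for v in model.non_trainable_weights],
#                        state.model.non_trainable)
```

If the first assertion fails right after trainer.initialize() but passes after state_with_new_model_weights, the pre-trained weights have made it into the state.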



Source: https://stackoverflow.com/questions/61786305/tff-loading-a-pre-trained-keras-model
