How to initialize the model with certain weights?

让人想犯罪 __ submitted on 2021-01-28 11:22:27

Question


I am using the "stateful_clients" example from the tensorflow-federated examples, and I want to initialize the model with my pretrained weights. I call keras_model.load_weights(init_weight) inside the model function, but it does not seem to take effect: the validation accuracy in the first round is still low. How can I solve this?

def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = get_five_layers_cnn([28, 28, 1])
    keras_model.load_weights(init_weight)
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    return stateful_fedavg_tf.KerasModelWrapper(keras_model,
                                                test_data.element_spec, loss)

Answer 1:


A quick primer on state and model weights in TFF

TFF takes a distinct perspective on state in machine learning, generally a consequence of its desire to be purely functional.

Usually in machine learning, a model is conceptually a function that takes data and produces a prediction. However, the term is somewhat overloaded: does 'model' refer to a trained model (which fits the description above), or to an architecture that is parameterized by its weights and therefore must accept those weights as an argument to be truly a 'function'? Somewhere in between is the notion of a 'stateful function', which I think is what people usually mean when they say 'model'.

TFF standardizes on the second understanding, the parameterized architecture. For TFF, a 'model' is a function that accepts parameters along with data as arguments and produces a prediction. This avoids the notion of a stateful function, which a purely functional perspective disallows (f(x) == f(x) should always be true, so f cannot have any state that affects its output).
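
To make the distinction concrete, here is a minimal sketch in plain Python. It is not TFF code, and the names StatefulModel and predict are made up for illustration. It contrasts a model that hides its weights with a function that takes them as an explicit argument.

```python
from typing import Sequence


class StatefulModel:
    """A 'model' in the everyday sense: the weights live inside the object."""

    def __init__(self, weights: Sequence[float]) -> None:
        self.weights = list(weights)

    def __call__(self, x: Sequence[float]) -> float:
        # The output depends on hidden state, so the same x can map to
        # different results at different times.
        return sum(w * xi for w, xi in zip(self.weights, x))


def predict(weights: Sequence[float], x: Sequence[float]) -> float:
    """A 'model' in the functional sense: parameters are an explicit argument."""
    return sum(w * xi for w, xi in zip(weights, x))


x = [1.0, 2.0]

stateful = StatefulModel([0.5, 0.5])
before = stateful(x)
stateful.weights = [1.0, 1.0]  # hidden state changes...
after = stateful(x)            # ...and so does the result for the same x

# The pure version gives the same output for the same (weights, x) pair.
pure_a = predict([0.5, 0.5], x)
pure_b = predict([0.5, 0.5], x)
```

In the functional version, "initializing with certain weights" just means choosing which weights value is passed in first, which is why the rest of this answer focuses on the state returned by initialize.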

On the code in question

I'm not super familiar with this portion of the TFF codebase; in particular I'm a little surprised at the behavior of the keras model wrapper, as usually TFF wants to serialize all logic into TFF-defined data structures as soon as possible (at least, this is how I think about it). Glancing at the code, it looks to me like it could work--but there have been exciting interactions between TFF and Keras in the past.

Briefly, here is how this path should be working:

  1. The model function you define above is invoked while building the initialize computation; the logic to load weights or the weights themselves would hopefully be serialized into the graph that TFF generates to represent initialize.
  2. Upon calling iterative_process.initialize, you would find your desired weights populated in the appropriate attributes of the returned data structure. This would serve as your initial starting point for your iterative process, and you would be off to the races.

What I am suspicious of in the above is step 1. If keras_model.load_weights does not get serialized as TF logic (knowing a bit about Keras, I would assume it does not; it is likely pure Python logic), it will not be run when you call iterative_process.initialize, and the state returned from initialize won't have your specified weights. TFF is designed from the ground up for deployment to devices that do not have a Python interpreter. It uses TensorFlow effectively as a serialized representation of logic, and logic that cannot be captured by TF can't be run.
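
The suspicion can be illustrated with a toy analogy in plain Python. This is not how TFF or Keras is actually implemented; ToyVariable, serialize_initialize and run_initialize are hypothetical names, and the only assumption modeled is that serialization captures a variable's initializer but not values assigned later from Python.

```python
from typing import Callable, List


class ToyVariable:
    """Hypothetical stand-in for a framework variable."""

    def __init__(self, initializer: Callable[[], float]) -> None:
        self.initializer = initializer  # the part a "graph" would record
        self.value = initializer()      # the live, Python-side value


def load_weights(variables: List[ToyVariable], new_values: List[float]) -> None:
    """Pure-Python side effect: changes live values, leaves initializers alone."""
    for var, val in zip(variables, new_values):
        var.value = val


def serialize_initialize(variables: List[ToyVariable]) -> List[Callable[[], float]]:
    """Captures only what the toy 'graph' knows about: the initializers."""
    return [var.initializer for var in variables]


def run_initialize(serialized: List[Callable[[], float]]) -> List[float]:
    """Replays the serialized logic, with no access to the Python objects."""
    return [init() for init in serialized]


variables = [ToyVariable(lambda: 0.0), ToyVariable(lambda: 0.0)]
load_weights(variables, [0.7, -1.2])  # plays the role of keras_model.load_weights

live_values = [var.value for var in variables]
initial_state = run_initialize(serialize_initialize(variables))
```

Here live_values holds the loaded weights while initial_state falls back to the initializers, which would match the symptom of a low first-round validation accuracy.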

If this suspicion is true, the appropriate solution is to run the weight-loading logic directly in Python. TFF provides some utilities to help with this kind of thing, like tff.learning.state_with_new_model_weights. This would be used like:

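state = iterative_process.initialize()
# Load the pretrained weights eagerly in plain Python, outside any TFF computation.
keras_model = get_five_layers_cnn([28, 28, 1])
keras_model.load_weights(init_weight)
# Assumed signature (check your TFF version): (state, trainable_weights, non_trainable_weights).
state_with_loaded_weights = tff.learning.state_with_new_model_weights(
    state, [v.numpy() for v in keras_model.trainable_weights],
    [v.numpy() for v in keras_model.non_trainable_weights])
...
# Continue the iterative process from state_with_loaded_weights.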
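
Conceptually, a "state with new model weights" helper returns a fresh state with one field swapped and leaves the original untouched. The sketch below shows that pattern in plain Python. ToyServerState, its field names and with_new_model_weights are hypothetical and are not the data structures used by TFF or by the stateful_clients example.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyServerState:
    """Hypothetical server state; the field names are illustrative only."""
    model_weights: Tuple[float, ...]
    round_num: int


def with_new_model_weights(
    state: ToyServerState, weights: Tuple[float, ...]
) -> ToyServerState:
    """Returns a copy of `state` with the weights swapped; the input is not mutated."""
    if len(weights) != len(state.model_weights):
        raise ValueError("new weights must match the structure of the old ones")
    return replace(state, model_weights=tuple(weights))


initial = ToyServerState(model_weights=(0.0, 0.0), round_num=0)
pretrained = (0.7, -1.2)
seeded = with_new_model_weights(initial, pretrained)
```

Inspecting the weights in the returned state before running the first round is a cheap way to confirm that the pretrained values are actually in place.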


Source: https://stackoverflow.com/questions/65273151/how-to-initialize-the-model-with-certain-weights
