Question
I am analyzing a method that I implemented in TensorFlow Federated with FedAvg. I need a histogram of each client's delta weights that are communicated to the server. Each client is called separately in simulation/federated_averaging.py, but I cannot call the tf.summary.histogram() API there. Any help would be appreciated.
Answer 1:
In TFF, TensorFlow represents "local computation", so if you need to inspect something across clients, you must either first aggregate the values you want via TFF or inspect the returned values in native Python.
If you want to use TF ops, I would recommend using the tff.federated_collect intrinsic to "gather" all the values you want on the server, and then using federated_map to apply a TF function that takes these values and produces your desired visualization.
If you would rather work at the Python level, there is an easy option (this is the approach I would take): simply return the results of training at the clients from your tff.federated_computation. When you invoke this computation, it will materialize a Python list of these results, and you can visualize it however you want. This would be roughly along these lines:
    @tff.federated_computation(...)
    def train_one_round(...):
        ...
        trained_clients = run_training(...)
        new_model = update_global_model(trained_clients, ...)
        return new_model, trained_clients
In this example, this function will return a tuple, the second element of which is a Python list representing the results of training at all clients.
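Once that list is back in Python, the histograms can be computed outside TFF. The following is a minimal sketch, not part of the original answer. It assumes each entry of the returned list is a list of per-layer arrays holding one client's delta weights. The helper name client_delta_histograms and the random stand-in data are hypothetical; the real structure depends on what your computation returns.

```python
import numpy as np


def client_delta_histograms(client_deltas, bins=10):
    """Return one (counts, bin_edges) pair per client.

    client_deltas is assumed to be a list with one entry per client, where
    each entry is a list of arrays (one per model layer). This layout is an
    assumption about what the federated computation returned.
    """
    # Flatten every client's per-layer arrays into a single 1-D vector.
    flat = [
        np.concatenate([np.ravel(np.asarray(layer, dtype=np.float64))
                        for layer in deltas])
        for deltas in client_deltas
    ]
    # Use one shared value range so the histograms are comparable across clients.
    lo = min(f.min() for f in flat)
    hi = max(f.max() for f in flat)
    return [np.histogram(f, bins=bins, range=(lo, hi)) for f in flat]


# Hypothetical stand-in for the materialized client results: 3 clients,
# each with a 4x2 kernel delta and a 2-element bias delta.
rng = np.random.default_rng(0)
fake_client_deltas = [
    [rng.normal(size=(4, 2)), rng.normal(size=(2,))] for _ in range(3)
]

for i, (counts, edges) in enumerate(client_delta_histograms(fake_client_deltas)):
    print(f"client {i}: counts={counts.tolist()}")
```

Because this runs in ordinary eager Python rather than inside a TFF computation, the same flattened vectors could also be passed to tf.summary.histogram(name, data, step=...) under a writer created with tf.summary.create_file_writer, if TensorBoard output is preferred.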
Source: https://stackoverflow.com/questions/60285187/how-to-plot-histogram-summary-for-delta-weight-in-federated-tensorflow