How to plot a histogram summary of delta weights in Federated TensorFlow?

Submitted by 泄露秘密 on 2020-04-16 05:44:34

Question


I am analyzing a method that I have implemented in TensorFlow Federated with FedAvg. I need a histogram of every client's delta weights that are communicated to the server. Each client is called separately in simulation/federated_avaraging.py, but I cannot call tf.summary.histogram() there. Any help would be appreciated.


Answer 1:


In TFF, TensorFlow represents "local computation", so if you need to inspect something across clients, you must first aggregate the values you want via TFF, or inspect the returned values in native Python.

If you want to use TF ops, I would recommend using the tff.federated_collect intrinsic to "gather" all the values you want on the server, then using tff.federated_map to apply a TF function that takes these values and produces your desired visualization.

If you would rather work at the Python level, there is an easy option (this is the approach I would take): simply return the results of training at the clients from your tff.federated_computation. When you invoke this computation from Python, those results are materialized as a Python list, which you can visualize however you want. This would look roughly like:

@tff.federated_computation(...)
def train_one_round(...):
  ...
  trained_clients = run_training(...)
  new_model = update_global_model(trained_clients,...)
  return new_model, trained_clients

In this example, the function returns a tuple whose second element is a Python list holding the results of training at all clients.
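Once that list is back in Python, binning the deltas needs no TFF at all. The sketch below is a minimal illustration under stated assumptions: each client's result has already been reduced to its weight delta (for example, client weights minus the global weights broadcast at the start of the round), stored as a list of NumPy arrays with one array per layer. The helper names `flatten_delta` and `client_delta_histograms` are hypothetical, not part of TFF. All clients share the same bin edges so their histograms are directly comparable.

```python
import numpy as np


def flatten_delta(layer_deltas):
    """Concatenate one client's per-layer weight deltas into a 1-D vector."""
    return np.concatenate([np.ravel(d) for d in layer_deltas])


def client_delta_histograms(client_deltas, num_bins=10):
    """Return shared bin edges and one count array per client.

    client_deltas: list (one entry per client) of lists of NumPy arrays.
    This structure is an assumption about what the materialized client
    results look like after subtracting the global weights.
    """
    flat = [flatten_delta(d) for d in client_deltas]
    # Compute edges over all clients' values so the bins line up.
    edges = np.histogram_bin_edges(np.concatenate(flat), bins=num_bins)
    # np.histogram returns (counts, edges); keep only the counts.
    counts = [np.histogram(v, bins=edges)[0] for v in flat]
    return edges, counts


# Usage with two made-up clients (illustrative values only).
deltas = [
    [np.array([[0.1, -0.2], [0.0, 0.3]]), np.array([0.05, -0.05])],
    [np.array([[-0.4, 0.2], [0.1, 0.0]]), np.array([0.2, 0.1])],
]
edges, counts = client_delta_histograms(deltas, num_bins=4)
```

If TensorBoard is the goal, each flattened vector can instead be passed to tf.summary.histogram(name, data, step=round_num) inside a `with tf.summary.create_file_writer(logdir).as_default():` block in TF 2.x eager code, outside any TFF computation, since at that point the values are ordinary Python/NumPy data.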



Source: https://stackoverflow.com/questions/60285187/how-to-plot-histogram-summary-for-delta-weight-in-federated-tensorflow