How to print local outputs in TensorFlow Federated?

…衆ロ難τιáo~ submitted on 2019-12-02 05:47:10

If you only want the list of per-client values that go into the aggregations (e.g. into tff.federated_mean), one option is to add extra outputs to aggregate_mnist_metrics_across_clients() that gather those values with tff.federated_collect().

This might look something like:

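```python
import tensorflow_federated as tff

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return {
      # Server-side aggregates (weighted by each client's example count).
      'num_examples': tff.federated_sum(metrics.num_examples),
      'loss': tff.federated_mean(metrics.loss, metrics.num_examples),
      'accuracy': tff.federated_mean(metrics.accuracy, metrics.num_examples),
      # Extra outputs: every client's raw value, gathered into a list on the server.
      # tff.federated_collect existed in the 2019 TFF releases; it may not be
      # available in newer versions.
      'per_client/num_examples': tff.federated_collect(metrics.num_examples),
      'per_client/loss': tff.federated_collect(metrics.loss),
      'per_client/accuracy': tff.federated_collect(metrics.accuracy),
  }
```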
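To make the difference between the aggregating operators and the collecting one concrete, here is a rough plain-Python model of what the function above returns. It is not TFF code: ClientMetrics, simulate_sum, simulate_weighted_mean, simulate_collect and aggregate_metrics_across_clients are hypothetical names, a Python list stands in for the client-placed values, and the two clients' numbers are toy values.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class ClientMetrics:
    """Hypothetical stand-in for one client's local metrics."""
    num_examples: float
    loss: float
    accuracy: float


def simulate_sum(values: List[float]) -> float:
    # Models a federated sum: the server only sees the total.
    return sum(values)


def simulate_weighted_mean(values: List[float], weights: List[float]) -> float:
    # Models a weighted federated mean: sum(w_i * v_i) / sum(w_i).
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def simulate_collect(values: List[float]) -> List[float]:
    # Models a federated collect: the server sees every client's value
    # as a list, with no client identifier attached.
    return list(values)


def aggregate_metrics_across_clients(
    clients: List[ClientMetrics],
) -> Dict[str, Union[float, List[float]]]:
    num_examples = [c.num_examples for c in clients]
    losses = [c.loss for c in clients]
    accuracies = [c.accuracy for c in clients]
    return {
        'num_examples': simulate_sum(num_examples),
        'loss': simulate_weighted_mean(losses, num_examples),
        'accuracy': simulate_weighted_mean(accuracies, num_examples),
        'per_client/num_examples': simulate_collect(num_examples),
        'per_client/loss': simulate_collect(losses),
        'per_client/accuracy': simulate_collect(accuracies),
    }


if __name__ == '__main__':
    # Toy values for two hypothetical clients.
    clients = [
        ClientMetrics(num_examples=10.0, loss=2.0, accuracy=0.5),
        ClientMetrics(num_examples=30.0, loss=1.0, accuracy=0.25),
    ]
    result = aggregate_metrics_across_clients(clients)
    print(result['accuracy'])             # 0.3125
    print(result['per_client/accuracy'])  # [0.5, 0.25]
```

The aggregate entries fold all clients together, while the per_client/ entries keep each value separate but carry no client identity.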

These values then show up a few cells later in the notebook, when the computation runs:

```python
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
```

```
round  1, metrics=<...,per_client/accuracy=[0.14516129, 0.10642202, 0.13972603],per_client/loss=[3.2409852, 3.417463, 2.9516447],per_client/num_examples=[930.0, 1090.0, 730.0]>
```
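As a sanity check, the aggregated loss and accuracy can be recomputed from the printed per-client lists with the weighted-mean formula sum(w_i * v_i) / sum(w_i). The aggregated values themselves are elided (`...`) in the output above, so the numbers below are computed here in plain Python, not copied from TFF. The sketch also assumes the three collected lists are in the same client order; weighted_mean is a hypothetical helper name.

```python
# Per-client values copied from the printed round-1 metrics above.
per_client_accuracy = [0.14516129, 0.10642202, 0.13972603]
per_client_loss = [3.2409852, 3.417463, 2.9516447]
per_client_num_examples = [930.0, 1090.0, 730.0]


def weighted_mean(values, weights):
    """Weighted mean sum(w_i * v_i) / sum(w_i), modelled in plain Python.

    Assumes values[i] and weights[i] belong to the same client.
    """
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


total_examples = sum(per_client_num_examples)
round_accuracy = weighted_mean(per_client_accuracy, per_client_num_examples)
round_loss = weighted_mean(per_client_loss, per_client_num_examples)

print(total_examples)            # 2750.0
print(round(round_accuracy, 4))  # 0.1284
print(round(round_loss, 4))      # 3.2341
```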

Note, however, that if you want to know the value for a specific client, there is intentionally no way to do that. By design, TFF's language avoids any notion of client identity; the goal is to avoid making individual clients addressable.
