Batch Normalization doesn't have gradients in TensorFlow 2.0?

Submitted by 依然范特西╮ on 2019-12-04 12:42:00

The problem is here:

gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)

You should only be computing gradients for the trainable variables, so you should change it to:

gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)

The same goes for the three lines that follow. The variables field also includes non-trainable variables, such as the moving mean and moving variance that batch norm uses during inference. Because those are not updated by gradient descent during training, no sensible gradients are defined for them: the tape returns None for those entries, and passing the None gradients to the optimizer crashes.
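As a minimal sketch of the difference (the toy generator, noise shape, and placeholder loss below are made up for illustration, since the model from the question isn't shown):

import tensorflow as tf

# Hypothetical stand-in for the generator from the question,
# including a BatchNormalization layer.
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(8, use_bias=False, input_shape=(4,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(),
    tf.keras.layers.Dense(2),
])

# BatchNormalization adds two non-trainable variables (moving_mean and
# moving_variance), so the two lists have different lengths:
print(len(generator.variables))            # trainable + non-trainable
print(len(generator.trainable_variables))  # trainable only

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

noise = tf.random.normal([16, 4])
with tf.GradientTape() as gen_tape:
    generated = generator(noise, training=True)
    gen_loss = tf.reduce_mean(tf.square(generated))  # placeholder loss

# Differentiate only with respect to the trainable variables; the tape
# would return None for the moving averages, and feeding those Nones
# to apply_gradients is what causes the crash.
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
generator_optimizer.apply_gradients(
    zip(gradients_of_generator, generator.trainable_variables))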
