I am using JAX to implement a simple neural network (NN), and I want to access and save the gradients from the backward pass for further analysis after the NN has run. I can acce