According to the JAX documentation, `jax.lax.scan` is a way to write a for-loop in functional style, and it can be used with functions that define a custom JVP.
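For context, here is a minimal sketch of that pattern. It assumes a loop that accumulates a running total: the loop body becomes a `step(carry, x)` function that returns `(new_carry, output)`, and `jax.lax.scan` applies it over the leading axis of `xs`. The `softplus` function and its hand-written JVP rule are hypothetical illustrations, not taken from the post. They only show that a `jax.custom_jvp` function can be called inside the scanned body and still be differentiated.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function with a hand-written JVP rule (softplus).
@jax.custom_jvp
def softplus(x):
    return jnp.logaddexp(x, 0.0)  # log(1 + e^x), numerically stable

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = softplus(x)
    # d/dx log(1 + e^x) = sigmoid(x); the rule is linear in x_dot
    return y, jax.nn.sigmoid(x) * x_dot

def step(carry, x):
    # carry: running total so far; x: current element of xs
    new_carry = carry + softplus(x)
    return new_carry, new_carry  # (next carry, per-step output)

def cumulative_softplus(xs):
    # Equivalent to: total = 0; for x in xs: total += softplus(x)
    init = jnp.zeros((), dtype=xs.dtype)  # carry dtype matches the loop output
    total, partials = jax.lax.scan(step, init, xs)
    return total, partials

xs = jnp.array([0.0, 1.0, -1.0])
total, partials = cumulative_softplus(xs)
# Reverse-mode gradient flows through scan and the custom JVP rule
grads = jax.grad(lambda v: cumulative_softplus(v)[0])(xs)
```

Because the total is a sum of `softplus(x_i)`, each gradient entry should equal `sigmoid(x_i)`, which is exactly what the custom rule supplies.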
However, it is very common that the argu