I am using JAX for some machine learning jobs. JAX uses XLA to just-in-time compile code for acceleration, but the compilation itself is too slow on CPU. My situation is that th
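One way to check how much wall-clock time goes to XLA compilation rather than execution is to time the two phases separately with JAX's ahead-of-time `jax.jit(f).lower(...).compile()` API. The following is a minimal sketch that assumes a reasonably recent JAX release. The `loss` function and the array shapes are hypothetical placeholders, not the actual model from the question.

```python
import time

import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical toy model: the Python loop is unrolled during tracing,
    # so it yields 10 matmul + tanh stages in the compiled program.
    for _ in range(10):
        x = jnp.tanh(x @ w)
    return jnp.sum(x)


x = jnp.ones((128, 128))
w = jnp.ones((128, 128)) * 0.01

jitted = jax.jit(loss)

# Phase 1: trace and lower to XLA's input IR, then compile for the current backend.
t0 = time.perf_counter()
lowered = jitted.lower(w, x)
compiled = lowered.compile()
compile_s = time.perf_counter() - t0

# Phase 2: run the already-compiled executable.
# block_until_ready() waits for JAX's asynchronous dispatch to finish.
t0 = time.perf_counter()
result = compiled(w, x).block_until_ready()
run_s = time.perf_counter() - t0

print(f"lower+compile: {compile_s:.3f}s, execute: {run_s:.5f}s")
```

`jax.jit` caches compiled executables by the shapes and dtypes of the inputs (plus any static arguments). Repeated calls with the same shapes reuse the cached executable, while each new input shape triggers another compilation.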