How to handle JAX reshape with JIT
Question: I am trying to implement entmax-alpha as described here. Here is the code. import jax import jax.numpy as jnp from jax import custom_jvp from jax import jit from jax import lax from jax import vmap @jax.partial(jit, static_argnums=(2,)) def p_tau(z, tau, alpha=1.5): return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1)) @jit def get_tau(tau, tau_max, tau_min, z_value): return lax.cond(z_value < 1, lambda _: (tau, tau_min), lambda _: (tau_max, tau), operand=None ) @jit def