I am trying to implement entmax-alpha as described here.
Here is the code.
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax
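For comparison, here is a minimal sketch of alpha-entmax in JAX. It assumes the standard definition from Peters et al. (2019), entmax_alpha(z) = [(alpha - 1) z - tau]_+^(1/(alpha - 1)), where tau is chosen so the output sums to 1. It also assumes alpha is a Python float strictly greater than 1 and that the transform is applied over the last axis. The names `entmax`, `_entmax_forward` and `N_ITER` are hypothetical, chosen only for this example. The forward pass finds tau by bisection. The derivative comes from a `jax.custom_jvp` rule that uses the closed-form Jacobian diag(s) - s s^T / sum(s), with s_i = p_i^(2 - alpha) on the support and 0 elsewhere, so JAX never differentiates through the loop.

```python
import functools

import jax
import jax.numpy as jnp

N_ITER = 50  # bisection steps (assumption: plenty for float32 precision)


def _entmax_forward(z, alpha):
    """Alpha-entmax over the last axis; finds the threshold tau by bisection."""
    x = (alpha - 1.0) * z
    d = z.shape[-1]
    max_x = jnp.max(x, axis=-1, keepdims=True)
    # Bracket for tau: total mass is >= 1 at tau_lo and <= 1 at tau_hi.
    tau_lo = max_x - 1.0
    tau_hi = max_x - (1.0 / d) ** (alpha - 1.0)

    def p_of(tau):
        return jnp.maximum(x - tau, 0.0) ** (1.0 / (alpha - 1.0))

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        f = jnp.sum(p_of(mid), axis=-1, keepdims=True) - 1.0
        # Mass decreases in tau: if it is still >= 1, the root lies above mid.
        lo = jnp.where(f >= 0, mid, lo)
        hi = jnp.where(f >= 0, hi, mid)
        return lo, hi

    lo, hi = jax.lax.fori_loop(0, N_ITER, body, (tau_lo, tau_hi))
    p = p_of(0.5 * (lo + hi))
    # Renormalize to absorb the remaining bisection error.
    return p / jnp.sum(p, axis=-1, keepdims=True)


@functools.partial(jax.custom_jvp, nondiff_argnums=(1,))
def entmax(z, alpha):
    return _entmax_forward(z, alpha)


@entmax.defjvp
def _entmax_jvp(alpha, primals, tangents):
    # With nondiff_argnums, the non-differentiable alpha is passed first.
    (z,), (dz,) = primals, tangents
    p = _entmax_forward(z, alpha)
    # s_i = p_i^(2 - alpha) on the support, 0 off the support.
    s = jnp.where(p > 0, p ** (2.0 - alpha), 0.0)
    # Apply J = diag(s) - s s^T / sum(s) to the tangent dz.
    s_dz = jnp.sum(s * dz, axis=-1, keepdims=True)
    dp = s * dz - s * s_dz / jnp.sum(s, axis=-1, keepdims=True)
    return p, dp
```

Usage: `entmax(z, 1.5)` gives 1.5-entmax and `entmax(z, 2.0)` gives sparsemax. For example, `entmax(jnp.array([0.5, 0.0, -1.0]), 2.0)` is approximately `[0.75, 0.25, 0.0]`. The JVP is linear in `dz`, so `jax.grad` and `jax.jacrev` also work through transposition. The alpha = 1 (softmax) limit is not handled because the exponent 1/(alpha - 1) diverges there.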