问题
I am trying to implement entmax-alpha as is described in here.
Here is the code.
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import jit
from jax import lax
from jax import vmap
@jax.partial(jit, static_argnums=(2,))
def p_tau(z, tau, alpha=1.5):
return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1))
@jit
def get_tau(tau, tau_max, tau_min, z_value):
return lax.cond(z_value < 1,
lambda _: (tau, tau_min),
lambda _: (tau_max, tau),
operand=None
)
@jit
def body(kwargs, x):
tau_min = kwargs['tau_min']
tau_max = kwargs['tau_max']
z = kwargs['z']
alpha = kwargs['alpha']
tau = (tau_min + tau_max) / 2
z_value = p_tau(z, tau, alpha).sum()
taus = get_tau(tau, tau_max, tau_min, z_value)
tau_max, tau_min = taus[0], taus[1]
return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None
@jax.partial(jit, static_argnums=(1, 2,))
def map_row(z_input, alpha, T):
z = (alpha - 1) * z_input
tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
length=T)
tau = (result['tau_max'] + result['tau_min']) / 2
result = p_tau(z, tau, alpha)
return result / result.sum()
@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
reduce_length = input.shape[axis]
input = jnp.swapaxes(input, -1, axis)
input = input.reshape(input.size / reduce_length, reduce_length)
result = vmap(jax.partial(map_row, alpha=alpha, T=T), 0)(input)
return jnp.swapaxes(result, -1, axis)
@jax.partial(jit, static_argnums=(1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
input = primals[0]
Y = entmax(input, axis, alpha, T)
gppr = Y ** (2 - alpha)
grad_output = tangents[0]
dX = grad_output * gppr
q = dX.sum(axis=axis) / gppr.sum(axis=axis)
q = jnp.expand_dims(q, axis=axis)
dX -= q * gppr
return Y, dX
@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
return _entmax_jvp_impl(axis, alpha, T, primals, tangents)
When I call it with the following code:
import numpy as np
from jax import value_and_grad
input = jnp.array(np.random.randn(64, 10))
weight = jnp.array(np.random.randn(64, 10))
def toy(input, weight):
return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
value_and_grad(toy)(input, weight)
I got the following error.
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-3a62e54c67d2> in <module>()
7 return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
8
----> 9 value_and_grad(toy)(input, weight)
35 frames
<ipython-input-1-d85b1daec668> in entmax(input, axis, alpha, T)
49 @jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
50 def entmax(input, axis=-1, alpha=1.5, T=10):
---> 51 reduce_length = input.shape[axis]
52 input = jnp.swapaxes(input, -1, axis)
53 input = input.reshape(input.size / reduce_length, reduce_length)
TypeError: tuple indices must be integers or slices, not DynamicJaxprTracer
It seems to be always connected to the reshape operations. I am not sure why this happens, and any help will be really appreciated.
To recreate the problem, here is the colab notebook
Thanks a lot.
回答1:
The error comes from the fact that you are attempting to index a Python tuple with a traced quantity, axis
. You can fix this error by making axis
a static argument:
@jax.partial(jit, static_argnums=(0, 1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
...
Unfortunately, this uncovers another problem: p_tau
declares that the alpha
parameter is static, but body()
calls this with a traced quantity. This quantity cannot be easily marked static in body
because it is passed within a dictionary of parameters that contains the input that is being traced.
To fix this, you'll have to rewrite your function signatures, carefully marking in each one which inputs are static and which are not, and making sure the two do not mix across the layers of function calls.
来源:https://stackoverflow.com/questions/65505103/how-to-handle-jax-reshape-with-jit