How to handle JAX reshape with JIT

核能气质少年 提交于 2021-01-28 22:01:59

问题


I am trying to implement entmax-alpha as described here.

Here is the code.

import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import jit
from jax import lax
from jax import vmap


# Entmax mapping for a fixed threshold tau:
#   p = clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))
# Argument index 2 (alpha) is declared static, so it must be a hashable Python
# value at call time, not a tracer.
# NOTE(review): `jax.partial` was an alias for functools.partial in older JAX
# releases and is not available in current ones; confirm the JAX version, or
# use functools.partial(jit, static_argnums=...) instead.
@jax.partial(jit, static_argnums=(2,))
def p_tau(z, tau, alpha=1.5):
    """Return the unnormalised entmax probabilities of ``z`` at threshold ``tau``.

    NOTE(review): callers must pass ``alpha`` as a concrete Python number;
    ``body`` passes it out of a ``lax.scan`` carry, where it is traced.
    """
    return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1))


# One bisection update of the threshold bracket. Note the return order:
# (new tau_max, new tau_min), which differs from the parameter order.
@jit
def get_tau(tau, tau_max, tau_min, z_value):
    """Shrink the bracket [tau_min, tau_max] around the midpoint ``tau``.

    ``z_value`` is the total mass ``p_tau(z, tau, alpha).sum()``. Since p_tau is
    non-increasing in tau (for alpha > 1), a mass below 1 means tau is too
    large, so it becomes the new upper bound; otherwise it becomes the new
    lower bound.

    Returns:
        Tuple ``(tau_max, tau_min)`` for the next iteration.
    """
    # lax.cond instead of a Python `if`, because z_value is traced under jit.
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


# lax.scan step function: one bisection iteration on the threshold tau.
@jit
def body(kwargs, x):
    """Run one bisection step.

    Args:
        kwargs: scan carry dict with keys 'tau_min', 'tau_max', 'z', 'alpha'.
        x: per-iteration scan input; unused (map_row scans with xs=None).

    Returns:
        ``(new_carry, None)``: the updated carry and no per-step output.
    """
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    # NOTE(review): lax.scan traces every leaf of its carry, so `alpha` is a
    # tracer here and cannot satisfy p_tau's static argument. Passing the
    # constants z and alpha via a closure/partial instead of the carry avoids this.
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    # get_tau returns (tau_max, tau_min), in that order.
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None


# Entmax for a single 1-D row via T bisection steps on the threshold tau.
# alpha (1) and T (2) are static; T is used as the lax.scan length, which must
# be a Python int.
@jax.partial(jit, static_argnums=(1, 2,))
def map_row(z_input, alpha, T):
    """Return entmax-alpha probabilities for one row ``z_input``.

    Args:
        z_input: 1-D array of scores (``z.shape[0]`` is used as the row length).
        alpha: entmax exponent; presumably > 1 (p_tau divides by alpha - 1).
        T: number of bisection iterations.

    Returns:
        Array of the same shape as ``z_input``, normalised to sum to 1.
    """
    # NOTE(review): z is already scaled by (alpha - 1) here, and p_tau scales its
    # input by (alpha - 1) again, so probabilities are evaluated on
    # (alpha - 1) ** 2 * z_input while the tau bounds below are derived from the
    # singly scaled z. Confirm against the entmax reference implementation,
    # which presumably applies the (alpha - 1) factor only once.
    z = (alpha - 1) * z_input

    # Initial bracket for tau; the upper bound subtracts d ** (1 - alpha), d = row length.
    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    # Renormalise so the output sums to 1 despite the finite number of bisection steps.
    return result / result.sum()


# Differentiable entmax-alpha along `axis`, with a hand-written JVP rule.
# axis, alpha and T (positions 1-3) are non-differentiable arguments.
@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    """Apply entmax-alpha to ``input`` along ``axis``.

    ``axis`` must be a concrete Python int: it indexes the Python tuple
    ``input.shape``, which raises ``TypeError`` when ``axis`` is a tracer.
    """
    reduce_length = input.shape[axis]
    # Move the reduction axis last, then flatten everything else into rows.
    input = jnp.swapaxes(input, -1, axis)
    # NOTE(review): `/` is true division and yields a float, but reshape needs
    # integer dimensions; this should be `input.size // reduce_length`.
    input = input.reshape(input.size / reduce_length, reduce_length)
    # map_row is vmapped over rows. alpha and T are bound by keyword.
    # NOTE(review): whether keyword arguments honour map_row's static_argnums
    # depends on the JAX version (newer versions also offer static_argnames);
    # confirm.
    result = vmap(jax.partial(map_row, alpha=alpha, T=T), 0)(input)
    # NOTE(review): result is still 2-D here; without reshaping it back to the
    # swapped input shape, the original layout is only restored for 2-D inputs.
    return jnp.swapaxes(result, -1, axis)


# Jitted body of entmax's JVP rule.
# NOTE(review): static_argnums=(1, 2) marks alpha and T static but not axis
# (position 0). axis is therefore traced, and entmax's `input.shape[axis]`
# raises "tuple indices must be integers or slices, not DynamicJaxprTracer".
# Marking position 0 static as well, i.e. static_argnums=(0, 1, 2), fixes that.
@jax.partial(jit, static_argnums=(1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    """Compute entmax's primal output and its tangent.

    Args:
        axis, alpha, T: the non-differentiable arguments of entmax.
        primals: tuple whose first element is the entmax input.
        tangents: tuple whose first element is the input's tangent.

    Returns:
        ``(Y, dX)`` where ``Y = entmax(input, axis, alpha, T)`` and
        ``dX = g * t - (sum(g * t) / sum(g)) * g`` with ``g = Y ** (2 - alpha)``,
        sums taken along ``axis``.
    """
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    # g = Y ** (2 - alpha) is zero wherever Y is zero (for alpha < 2), so the
    # tangent is confined to the support of Y.
    gppr = Y ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    # Re-insert the reduced axis so q broadcasts against gppr.
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


# Registers the custom JVP rule. With nondiff_argnums=(1, 2, 3), JAX passes
# entmax's axis, alpha and T first, followed by the primals and tangents tuples.
@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    """Delegate to the jitted ``_entmax_jvp_impl``; return ``(primal_out, tangent_out)``."""
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

When I call it with the following code:

# Minimal reproduction: differentiate a weighted sum of entmax outputs.
import numpy as np
from jax import value_and_grad
# 64 rows of 10 random scores, plus matching random weights.
input = jnp.array(np.random.randn(64, 10))
weight = jnp.array(np.random.randn(64, 10))

def toy(input, weight):
    """Return the total of ``weight * entmax(input, axis=-1)`` over all elements."""
    return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()

# This call fails with the TypeError reported in the question: value_and_grad
# goes through entmax's custom JVP rule, where axis is traced.
value_and_grad(toy)(input, weight)

I got the following error.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-3a62e54c67d2> in <module>()
      7     return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
      8 
----> 9 value_and_grad(toy)(input, weight)

35 frames
<ipython-input-1-d85b1daec668> in entmax(input, axis, alpha, T)
     49 @jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
     50 def entmax(input, axis=-1, alpha=1.5, T=10):
---> 51     reduce_length = input.shape[axis]
     52     input = jnp.swapaxes(input, -1, axis)
     53     input = input.reshape(input.size / reduce_length, reduce_length)

TypeError: tuple indices must be integers or slices, not DynamicJaxprTracer

It always seems to be related to the reshape operations. I am not sure why this happens; any help would be really appreciated.

To reproduce the problem, see the Colab notebook (the link was not preserved in this copy).

Thanks a lot.


回答1:


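The error comes from the fact that you are attempting to index a Python tuple (`input.shape`) with a traced quantity, `axis`. `_entmax_jvp_impl` is jitted with `static_argnums=(1, 2)`, which marks only `alpha` and `T` as static. `axis` (argument 0) is therefore traced, and it reaches `entmax` as a tracer. You can fix this error by making `axis` a static argument as well: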

# Mark axis (0), alpha (1) and T (2) static. All three must be concrete Python
# values: axis indexes input.shape, alpha is passed on as a static argument
# downstream, and T becomes the lax.scan length.
@jax.partial(jit, static_argnums=(0, 1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    ...

Unfortunately, this uncovers another problem: `p_tau` declares its `alpha` parameter as static, but `body()` calls it with a traced quantity. `alpha` cannot easily be marked static in `body`, because it is passed inside the `lax.scan` carry dictionary together with the traced input, and every value in that carry is traced.

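To fix this, you'll have to rewrite your function signatures. In each one, carefully mark which inputs are static and which are not, and make sure the two do not mix across the layers of function calls. For example, bind constant values such as `alpha` into `body` through a closure or `functools.partial` instead of carrying them through `lax.scan`.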



来源:https://stackoverflow.com/questions/65505103/how-to-handle-jax-reshape-with-jit

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!