numpy.matmul in Theano

强颜欢笑 posted on 2020-01-04 09:38:15

Question


TL;DR
I want to replicate the functionality of numpy.matmul in Theano. What's the best way to do this?

Too Short; Didn't Understand
Looking at theano.tensor.dot and theano.tensor.tensordot, I don't see an easy way to do a straightforward batch matrix multiplication, i.e. to treat the last two dimensions of N-dimensional tensors as matrices and multiply them. Do I need to resort to some goofy usage of theano.tensor.batched_dot? Or *shudder* loop over them myself without broadcasting!?
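
For reference, the behaviour being asked for can be seen in plain NumPy. This is only a minimal sketch: the shapes are arbitrary and not taken from the question. It shows that np.matmul multiplies the last two axes as matrices and broadcasts the leading axes.

```python
import numpy as np

# Shapes are arbitrary, chosen only for illustration.
a = np.arange(2 * 1 * 3 * 4, dtype=float).reshape(2, 1, 3, 4)  # stack of 3x4 matrices
b = np.arange(5 * 4 * 2, dtype=float).reshape(5, 4, 2)         # stack of 4x2 matrices

# The last two axes are multiplied as matrices; the leading axes
# (2, 1) and (5,) broadcast against each other to (2, 5).
c = np.matmul(a, b)
print(c.shape)  # (2, 5, 3, 2)

# Each output matrix is an ordinary 2-D product of the matching slices.
expected = np.dot(a[1, 0], b[3])
print(np.allclose(c[1, 3], expected))  # True
```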


Answer 1:


The current pull requests don't support broadcasting, so I came up with this for now. I may clean it up, add a little more functionality, and submit my own PR as a temporary solution. Until then, I hope this helps someone! I included the test to show that it replicates numpy.matmul, provided the input complies with my stricter (temporary) assertions.

Also, theano.scan stops iterating after min(sequence lengths) iterations, i.e. at the length of the shortest sequence. So I believe that mismatched array shapes along a scanned dimension won't raise any exceptions; the longer input would simply be cut short.

import theano as th
import theano.tensor as tt
import numpy as np


def matmul(a: tt.TensorVariable, b: tt.TensorVariable, _left=False):
    """Replicates the functionality of numpy.matmul, except that
    the two tensors must have the same number of dimensions, and their ndim must exceed 1."""

    # TODO ensure that broadcastability is maintained if both a and b are broadcastable on a dim.

    assert a.ndim == b.ndim  # TODO support broadcasting for differing ndims.
    ndim = a.ndim
    assert ndim >= 2

    # If we should left multiply, just swap the operands.
    if _left:
        a, b = b, a

    # If a and b are 2 dimensional, compute their matrix product.
    if ndim == 2:
        return tt.dot(a, b)
    # If they are larger...
    else:
        # If a is broadcastable but b is not.
        if a.broadcastable[0] and not b.broadcastable[0]:
            # Scan b, but hold a steady.
            # Because b will be passed in as a, we need to left multiply to maintain
            #  matrix orientation.
            output, _ = th.scan(matmul, sequences=[b], non_sequences=[a[0], 1])
        # If b is broadcastable but a is not.
        elif b.broadcastable[0] and not a.broadcastable[0]:
            # Scan a, but hold b steady.
            output, _ = th.scan(matmul, sequences=[a], non_sequences=[b[0]])
        # If neither dimension is broadcastable or they both are.
        else:
            # Scan through the sequences, assuming the shape for this dimension is equal.
            output, _ = th.scan(matmul, sequences=[a, b])
        return output


def matmul_test() -> bool:
    vlist = []
    flist = []
    ndlist = []
    for i in range(2, 30):
        dims = int(np.random.random() * 4 + 2)

        # Create a tuple of tensors with potentially different broadcastability.
        vs = tuple(
            tt.TensorVariable(
                tt.TensorType('float64',
                              tuple((p < .3) for p in np.random.ranf(dims-2))
                              # Make full matrices
                              + (False, False)
                )
            )
            for _ in range(2)
        )
        vs = tuple(tt.swapaxes(v, -2, -1) if j % 2 == 0 else v for j, v in enumerate(vs))

        f = th.function([*vs], [matmul(*vs)])

        # Create the default shape for the test ndarrays
        defshape = tuple(int(np.random.random() * 5 + 1) for _ in range(dims))
        # Create a test array matching the broadcastability of each v, for each v.
        nds = tuple(
            np.random.ranf(
                tuple(s if not v.broadcastable[j] else 1 for j, s in enumerate(defshape))
            )
            for v in vs
        )
        nds = tuple(np.swapaxes(nd, -2, -1) if j % 2 == 0 else nd for j, nd in enumerate(nds))

        ndlist.append(nds)
        vlist.append(vs)
        flist.append(f)

    for i in range(len(ndlist)):
        assert np.allclose(flist[i](*ndlist[i]), np.matmul(*ndlist[i]))

    return True


if __name__ == "__main__":
    print("matmul_test -> " + str(matmul_test()))
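
The recursion above can be mirrored in plain NumPy, which may make the control flow easier to follow without compiling a Theano graph. This is a sketch under assumptions, not part of the original answer: the name matmul_ref is hypothetical, a Python loop stands in for scan, and, unlike scan, it raises when the leading lengths cannot broadcast instead of stopping at the shorter one. It also assumes non-empty leading axes.

```python
import numpy as np


def matmul_ref(a, b):
    """Hypothetical NumPy mirror of the recursive approach above.

    Like the Theano version, it assumes a.ndim == b.ndim >= 2.
    Leading axes are assumed to be non-empty.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    assert a.ndim == b.ndim >= 2
    if a.ndim == 2:
        return np.dot(a, b)
    n_a, n_b = a.shape[0], b.shape[0]
    # Unlike scan, fail loudly when the leading lengths cannot broadcast.
    if n_a != n_b and 1 not in (n_a, n_b):
        raise ValueError("leading dimensions %d and %d do not match" % (n_a, n_b))
    n = max(n_a, n_b)
    # A length-1 leading axis is held steady while the other is iterated.
    return np.stack([
        matmul_ref(a[0 if n_a == 1 else i], b[0 if n_b == 1 else i])
        for i in range(n)
    ])


rng = np.random.RandomState(0)
x = rng.rand(3, 1, 2, 4)
y = rng.rand(1, 5, 4, 3)
print(np.allclose(matmul_ref(x, y), np.matmul(x, y)))  # True

try:
    matmul_ref(rng.rand(2, 2, 2), rng.rand(3, 2, 2))
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

The explicit length check is a design choice for the sketch: it trades the silent truncation described earlier for an immediate error, which matches what np.matmul does with incompatible shapes.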


Source: https://stackoverflow.com/questions/42169776/numpy-matmul-in-theano
