Question
TL;DR
I want to replicate the functionality of numpy.matmul in theano. What's the best way to do this?
Too Short; Didn't Understand
Looking at theano.tensor.dot and theano.tensor.tensordot, I'm not seeing an easy way to do a straightforward batch matrix multiplication, i.e. treat the last two dimensions of N-dimensional tensors as matrices and multiply them. Do I need to resort to some goofy usage of theano.tensor.batched_dot? Or *shudder* loop them myself without broadcasting!?
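For reference, this is the numpy.matmul behaviour being targeted: the last two axes are treated as matrices, and the leading (batch) axes broadcast the same way they do in elementwise operations. A small NumPy illustration; the shapes are arbitrary.

```python
import numpy as np

# numpy.matmul treats the last two axes as matrices and broadcasts the rest.
a = np.random.random((5, 1, 3, 4))  # batch shape (5, 1), matrices 3x4
b = np.random.random((6, 4, 2))     # batch shape (6,),   matrices 4x2
c = np.matmul(a, b)

# Batch axes broadcast like elementwise ops: (5, 1) with (6,) -> (5, 6).
print(c.shape)  # (5, 6, 3, 2)

# Each output matrix is an ordinary 2-D product of the matching slices;
# a's length-1 axis is reused for every index along b's batch axis.
assert np.allclose(c[2, 4], a[2, 0] @ b[4])
```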
Answer 1:
The current pull requests for this feature don't support broadcasting, so I came up with this for now. I may clean it up, add a little more functionality, and submit my own PR as a temporary solution. Until then, I hope this helps someone! I included a test to show that it replicates numpy.matmul, provided the inputs satisfy my stricter (temporary) assertions: both tensors must have the same number of dimensions, and at least two.
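One way around the same-ndim restriction, sketched here in NumPy only: left-pad the lower-rank operand with length-1 axes, which is how broadcasting aligns batch axes anyway. `pad_to_ndim` is a hypothetical helper written for this illustration. On the Theano side I would expect `tt.shape_padleft` to play the same role and to mark the new axes as broadcastable, so that `matmul` takes one of its broadcasting branches, but that combination is an untested assumption.

```python
import numpy as np


def pad_to_ndim(x, ndim):
    """Hypothetical helper: prepend length-1 axes until x has `ndim` dimensions."""
    return x.reshape((1,) * (ndim - x.ndim) + x.shape)


a = np.random.random((7, 3, 4))  # a batch of seven 3x4 matrices
b = np.random.random((4, 2))     # a single 4x2 matrix

ndim = max(a.ndim, b.ndim)
a_p, b_p = pad_to_ndim(a, ndim), pad_to_ndim(b, ndim)
print(b_p.shape)  # (1, 4, 2)

# Padding does not change the result, because broadcasting already
# aligns the leading axes this way for operands with ndim >= 2.
assert np.allclose(np.matmul(a_p, b_p), np.matmul(a, b))
```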
Also, theano.scan stops after min(len(s) for s in sequences) iterations, i.e. as soon as the shortest sequence is exhausted. So I believe mismatched array shapes won't raise any exceptions; the extra slices of the longer sequence are silently ignored.
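A pure-Python model of that iteration rule makes the consequence concrete. `scan_steps` and `scan_like` are hypothetical stand-ins written for illustration, not Theano code; they only mimic the step count scan uses when `n_steps` is not given.

```python
def scan_steps(*sequences):
    """Hypothetical model: with no explicit n_steps, run as many steps
    as the shortest sequence has elements."""
    return min(len(s) for s in sequences)


def scan_like(fn, sequences):
    """Hypothetical model: apply fn to aligned slices, stopping at the shortest sequence."""
    n = scan_steps(*sequences)
    return [fn(*(s[i] for s in sequences)) for i in range(n)]


# Mismatched lengths (3 vs 5): no error is raised, the extra items are dropped.
out = scan_like(lambda x, y: x + y, [[1, 2, 3], [10, 20, 30, 40, 50]])
print(out)  # [11, 22, 33]
```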
import theano as th
import theano.tensor as tt
import numpy as np


def matmul(a: tt.TensorVariable, b: tt.TensorVariable, _left=False):
    """Replicates the functionality of numpy.matmul, except that
    the two tensors must have the same number of dimensions, and their ndim must exceed 1."""
    # TODO ensure that broadcastability is maintained if both a and b are broadcastable on a dim.
    assert a.ndim == b.ndim  # TODO support broadcasting for differing ndims.
    ndim = a.ndim
    assert ndim >= 2
    # If we should left multiply, just swap references.
    if _left:
        a, b = b, a
    # If a and b are 2 dimensional, compute their matrix product.
    if ndim == 2:
        return tt.dot(a, b)
    # If they are larger, recurse over the leading dimension with scan.
    # scan passes the current slice(s) of `sequences` first, then the `non_sequences`.
    if a.broadcastable[0] and not b.broadcastable[0]:
        # a is broadcastable but b is not: scan b, but hold a[0] steady.
        # Because each slice of b is passed in as `a`, the trailing 1 sets _left
        # so the operands are swapped back and matrix orientation is maintained.
        output, _ = th.scan(matmul, sequences=[b], non_sequences=[a[0], 1])
    elif b.broadcastable[0] and not a.broadcastable[0]:
        # b is broadcastable but a is not: scan a, but hold b[0] steady.
        output, _ = th.scan(matmul, sequences=[a], non_sequences=[b[0]])
    else:
        # Neither dimension is broadcastable, or both are: scan through both
        # sequences, assuming their lengths along this dimension are equal.
        output, _ = th.scan(matmul, sequences=[a, b])
    return output


def matmul_test() -> bool:
    flist = []
    ndlist = []
    for _ in range(2, 30):
        dims = int(np.random.random() * 4 + 2)  # 2 to 5 dimensions
        # Create a tuple of tensors with potentially different broadcastability.
        vs = tuple(
            tt.TensorType(
                'float64',
                # Each leading dim is broadcastable with probability 0.3 ...
                tuple(bool(p < .3) for p in np.random.random(dims - 2))
                # ... and the last two dims are full matrices.
                + (False, False)
            )()  # calling a TensorType creates a fresh symbolic variable
            for _ in range(2)
        )
        # Swap the matrix axes of the first operand so the inner dimensions line up.
        vs = tuple(tt.swapaxes(v, -2, -1) if j % 2 == 0 else v for j, v in enumerate(vs))
        f = th.function(list(vs), matmul(*vs))
        # Create the default shape for the test ndarrays.
        defshape = tuple(int(np.random.random() * 5 + 1) for _ in range(dims))
        # Create a test array matching the broadcastability of each v, for each v.
        nds = tuple(
            np.random.random(
                tuple(s if not v.broadcastable[j] else 1 for j, s in enumerate(defshape))
            )
            for v in vs
        )
        nds = tuple(np.swapaxes(nd, -2, -1) if j % 2 == 0 else nd for j, nd in enumerate(nds))
        ndlist.append(nds)
        flist.append(f)
    for f, nds in zip(flist, ndlist):
        assert np.allclose(f(*nds), np.matmul(*nds))
    return True


if __name__ == "__main__":
    print("matmul_test -> " + str(matmul_test()))
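The branching logic of `matmul` can be sanity-checked without a Theano install by modelling it in NumPy. `matmul_ref` is a hypothetical name; it approximates Theano's static `broadcastable` flags with a runtime check for a leading axis of length 1 (Theano decides from the variable's type, not from the data), and plain Python loops stand in for `th.scan`.

```python
import numpy as np


def matmul_ref(a, b):
    """NumPy model of the recursive strategy above (not the Theano code itself).
    Requires equal ndim >= 2; a leading axis of length 1 plays the role of a
    broadcastable dimension."""
    assert a.ndim == b.ndim and a.ndim >= 2
    if a.ndim == 2:
        return a @ b  # base case: plain matrix product, like tt.dot
    if a.shape[0] == 1 and b.shape[0] != 1:
        # Hold a[0] steady and iterate over b (the branch that uses _left).
        parts = [matmul_ref(a[0], b_i) for b_i in b]
    elif b.shape[0] == 1 and a.shape[0] != 1:
        # Hold b[0] steady and iterate over a.
        parts = [matmul_ref(a_i, b[0]) for a_i in a]
    else:
        # Iterate over both together; zip stops at the shorter one, like scan.
        parts = [matmul_ref(a_i, b_i) for a_i, b_i in zip(a, b)]
    return np.stack(parts)


rng = np.random.default_rng(0)
a = rng.random((1, 5, 3, 4))
b = rng.random((6, 1, 4, 2))
assert np.allclose(matmul_ref(a, b), np.matmul(a, b))
```

The shapes are chosen so that both broadcasting branches run: the outer call holds `a[0]` fixed while iterating over `b`, and each inner call holds `b[0]` fixed while iterating over `a[0]`.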
Source: https://stackoverflow.com/questions/42169776/numpy-matmul-in-theano