TensorFlow has a function called batch_matmul which multiplies higher dimensional tensors. But I'm having a hard time understanding how it works, perhaps partially because
You can imagine it as doing a matmul over each training example in the batch.
For example, if you have two tensors with the following dimensions:
a.shape = [100, 2, 5]
b.shape = [100, 5, 2]
and you do a batch tf.matmul(a, b), your output will have the shape [100, 2, 2].
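Here 100 is your batch size; the other two dimensions are the dimensions of your data. Each of the 100 matrices of shape [2, 5] in a is multiplied by the corresponding [5, 2] matrix in b, giving 100 results of shape [2, 2].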
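
If it helps to see this concretely, here is a minimal sketch. It uses NumPy rather than TensorFlow (my substitution, so it runs without a TensorFlow install); np.matmul treats the leading dimension as a batch in the same way tf.matmul does for inputs with more than two dimensions. The random inputs are just placeholders with the shapes from above.

```python
import numpy as np

# Placeholder data with the shapes from the example above.
rng = np.random.default_rng(0)
a = rng.standard_normal((100, 2, 5))
b = rng.standard_normal((100, 5, 2))

# Batched product: the leading dimension (100) is treated as the batch.
batched = np.matmul(a, b)

# The same thing written as an explicit loop over the batch:
# one ordinary [2, 5] x [5, 2] matrix product per example.
looped = np.stack([a[i] @ b[i] for i in range(a.shape[0])])

print(batched.shape)                 # (100, 2, 2)
print(np.allclose(batched, looped))  # True
```

The loop version is only there to show what the batched call computes; in practice you would use the single batched call.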