Question
I recently switched from numpy to ND4J but had a hard time understanding how broadcasting in ND4J works.
Say I have two ndarrays, a of shape [3,2,4,5] and b of shape [2,4,5]. I would like to add them element-wise, broadcasting b to each a[i] for i = 0 to 2. In numpy this can simply be done via a + b, while in ND4J a.add(b) throws an exception. I also tried a.add(b.broadcast(3)), but still no luck.
What is the correct way of doing this in ND4J?
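To make the intended numpy semantics concrete, here is a minimal plain-Scala sketch. It does not use ND4J at all, and `broadcastAddLeading` is a hypothetical helper written only for illustration. It assumes both tensors are stored as flat row-major (C-order) arrays and covers only the case in the question, where b's shape equals the trailing dimensions of a (numpy's full rules also stretch size-1 dimensions, which this sketch does not handle). In that layout each a[i] is a contiguous block with as many elements as b, so flat element j of a pairs with element j % b.length of b.

```scala
object BroadcastSketch {
  // Hypothetical helper (not an ND4J API): adds b to every slice a(i),
  // mimicking numpy's `a + b` when b's shape equals a's trailing dimensions.
  // Both arrays are assumed to be flat, row-major (C-order) buffers.
  def broadcastAddLeading(a: Array[Double], aShape: Seq[Int],
                          b: Array[Double], bShape: Seq[Int]): Array[Double] = {
    require(aShape.takeRight(bShape.length) == bShape,
      s"shape $bShape does not match the trailing dimensions of $aShape")
    require(a.length == aShape.product && b.length == bShape.product,
      "buffer sizes do not match the declared shapes")
    val block = b.length // number of elements in one slice a(i)
    // Each slice a(i) is a contiguous run of `block` elements,
    // so flat index j of a lines up with index j % block of b.
    Array.tabulate(a.length)(j => a(j) + b(j % block))
  }

  def main(args: Array[String]): Unit = {
    val aShape = Seq(3, 2, 4, 5)
    val bShape = Seq(2, 4, 5)
    val a = Array.tabulate(aShape.product)(_.toDouble)       // 0.0 .. 119.0
    val b = Array.tabulate(bShape.product)(k => 1000.0 * k)  // 0.0 .. 39000.0
    val c = broadcastAddLeading(a, aShape, b, bShape)
    // Element 7 of every slice a(i) received the same b(7) = 7000.0.
    println(s"${c(7)} ${c(47)} ${c(87)}") // 7007.0 7047.0 7087.0
  }
}
```

The modulo indexing avoids materializing three copies of b; the trade-off is that it relies on the row-major layout assumption stated above.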
Answer 1:
The only method I have found so far is the following:
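import org.nd4j.linalg.factory.Nd4j

val a = Nd4j.createUninitialized(Array(3, 2, 4, 4))
// Flatten b from [2,4,4] to a [1,32] row, repeat that row 3 times ([3,32]),
// then reshape back to [3,2,4,4] so it matches a's shape.
val b = Nd4j.createUninitialized(Array(2, 4, 4))
  .reshape(1, 32)
  .broadcast(3, 32)
  .reshape(3, 2, 4, 4)
// add() returns a new array; addi() would modify a in place instead.
val c = a.add(b)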
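Why the reshape → broadcast(3, 32) → reshape sequence lines up correctly can be sketched on plain arrays. This is not ND4J code; `tile` and `addSameShape` are hypothetical helpers, and the sketch assumes the arrays use the default C (row-major) ordering, in which a reshape only reinterprets the same flat buffer. Under that assumption the net effect of the three steps on the data is simply "repeat b's 32 values three times", after which an ordinary same-shape element-wise add gives each slice a[i] + b.

```scala
object TileThenAdd {
  // Hypothetical helper: the data effect of reshape(1, 32) -> broadcast(3, 32)
  // -> reshape(3, 2, 4, 4) on a row-major buffer is b's values repeated `times` times.
  def tile(b: Array[Double], times: Int): Array[Double] =
    Array.fill(times)(b).flatten

  // Same-shape element-wise add, analogous to a.add(b) once the shapes match.
  def addSameShape(x: Array[Double], y: Array[Double]): Array[Double] = {
    require(x.length == y.length, "shapes must match after tiling")
    x.zip(y).map { case (p, q) => p + q }
  }

  def main(args: Array[String]): Unit = {
    val a = Array.tabulate(3 * 2 * 4 * 4)(_.toDouble)   // stands in for a: [3,2,4,4]
    val b = Array.tabulate(2 * 4 * 4)(k => 1000.0 * k)  // stands in for b: [2,4,4]
    val tiled = tile(b, 3)                              // 96 values, b repeated 3 times
    val c = addSameShape(a, tiled)
    // Element 5 of slice i (flat index 32 * i + 5) gets the same b(5) = 5000.0.
    println(s"${c(5)} ${c(37)} ${c(69)}") // 5005.0 5037.0 5069.0
  }
}
```

Unlike the modulo-indexing approach, tiling allocates a full copy of b for every slice, which is what the answer's broadcast(3, 32) call does as well.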
Please let me know if there is a better way to do this.
Source: https://stackoverflow.com/questions/42309075/using-broadcast-in-nd4j