Best way to mimic PyTorch sliced assignment with Keras/Tensorflow

南旧 2021-01-21 22:00

I am trying to mimic the operation done in PyTorch below:

vol = Variable(torch.FloatTensor(A, B*2, C, D, E).zero_()).cuda()
for i in range(C):
  if i > 0 :
2 Answers
  •  傲寒
    2021-01-21 23:04

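    A tf.Variable is a primitive, stateful type: it is a leaf of the computation graph. Writing values into it (including sliced assignment) is not differentiable, so you shouldn't expect gradients to propagate out of the variable back to whatever was written into it.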
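A minimal sketch of the problem, assuming TensorFlow 2.x in eager mode (the names vol and x and the sizes are illustrative only). After a sliced assignment into a tf.Variable, tf.GradientTape sees only the later read of the variable, not how its contents were produced, so there is no gradient path back to the assigned value.

```python
import tensorflow as tf

# A stateful buffer, analogous to the zero-initialised vol in the question.
vol = tf.Variable(tf.zeros([4]))
# A trainable "input" whose values are copied into a slice of the buffer.
x = tf.Variable([1.0, 2.0])

with tf.GradientTape() as tape:
    # Sliced assignment mutates the variable in place (PyTorch-style).
    vol[1:3].assign(x)
    # The tape records the read of vol, not the history of its contents.
    loss = tf.reduce_sum(vol * vol)

grad_x = tape.gradient(loss, x)
print(vol.numpy())  # [0. 1. 2. 0.]
print(grad_x)       # None: the assignment cut the link between loss and x
```

The forward values are what you would expect; it is only the backward pass that silently loses the connection.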

    What you want instead is to construct an op (a node in the graph) whose output is the 5-dimensional tensor you need.
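A small contrast to in-place assignment, again assuming TensorFlow 2.x (names and sizes are illustrative): the same [0, x0, x1, 0] layout is produced as the output of an op, here tf.pad, so the tape can differentiate through it. tf.pad is just one op that happens to give this layout; tf.concat or tf.stack behave the same way with respect to gradients.

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0])

with tf.GradientTape() as tape:
    # Emit a new tensor instead of writing into a zero buffer:
    # pad x with one zero on each side -> [0, x0, x1, 0].
    vol = tf.pad(x, [[1, 1]])
    loss = tf.reduce_sum(vol * vol)

# d(loss)/dx = 2 * x, because the padded zeros do not depend on x.
grad_x = tape.gradient(loss, x)
print(vol.numpy())     # [0. 1. 2. 0.]
print(grad_x.numpy())  # [2. 4.]
```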

    I would build that tensor with a concatenate operation along the 4th dimension and use the result in place of vol.
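A sketch of the concatenation approach under an explicit assumption: the question's loop body is cut off, so this guesses a hypothetical body that writes input0[..., i:] into vol[:, :B, i, :, i:] and input1[..., :E - i] into vol[:, B:, i, :, i:], leaving everything else at zero, with inputs of shape (A, B, D, E) in PyTorch's channels-first layout. Which axis you concatenate on depends on the real loop body and layout; in this hypothetical version the two inputs are concatenated along axis 1 (channels) and the per-iteration results are stacked along axis 2 (the C axis). The function name build_volume is made up for illustration.

```python
import tensorflow as tf


def build_volume(input0, input1, num_slices):
    """Assignment-free build of an (A, 2B, C, D, E) tensor.

    Hypothetical semantics: slice i holds input0[..., i:] in the first B
    channels and input1[..., :E - i] in the last B channels, both placed in
    columns i: of the last axis; columns :i stay zero. Needs num_slices <= E.
    """
    width = input0.shape[-1]  # E
    slices = []
    for i in range(num_slices):
        left = input0[..., i:]            # (A, B, D, E - i)
        right = input1[..., : width - i]  # (A, B, D, E - i); avoids x[..., :-0]
        pair = tf.concat([left, right], axis=1)  # (A, 2B, D, E - i)
        # Prepend i zero columns, mirroring entries the loop never writes.
        pair = tf.pad(pair, [[0, 0], [0, 0], [0, 0], [i, 0]])
        slices.append(pair)
    # Stack the C slices on a new axis 2 -> (A, 2B, C, D, E).
    return tf.stack(slices, axis=2)


# Usage with small, made-up sizes (A, B, C, D, E) = (2, 3, 4, 5, 6).
in0 = tf.random.normal([2, 3, 5, 6])
in1 = tf.random.normal([2, 3, 5, 6])
vol = build_volume(in0, in1, num_slices=4)
print(vol.shape)  # (2, 6, 4, 5, 6)
```

Because every step is an ordinary op, gradients flow from vol back to both inputs. Inside a tf.function the Python loop is unrolled at trace time, which is fine for a small, fixed C.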

    If you don't need gradients to propagate back to input0 and input1, I would simply build the tensor outside of TensorFlow (for example, with NumPy) and use it as the variable's initializer.
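A sketch of that route, assuming a hypothetical loop body (the question's is truncated) that copies input0[..., i:] and input1[..., :E - i] into vol[:, :B, i, :, i:] and vol[:, B:, i, :, i:]. The sizes and random inputs are placeholders. NumPy arrays support in-place sliced assignment, so a PyTorch-style loop ports almost line for line, and the finished array becomes the variable's initial value. Passing trainable=False is a choice made here so optimizers leave the tensor alone; a plain tf.constant would serve equally well.

```python
import numpy as np
import tensorflow as tf

A, B, C, D, E = 2, 3, 4, 5, 6  # placeholder sizes
rng = np.random.default_rng(0)
input0 = rng.standard_normal((A, B, D, E)).astype(np.float32)
input1 = rng.standard_normal((A, B, D, E)).astype(np.float32)

# NumPy allows in-place sliced writes, just like the PyTorch code.
vol_np = np.zeros((A, 2 * B, C, D, E), dtype=np.float32)
for i in range(C):
    # Hypothetical loop body; columns :i of each slice stay zero.
    vol_np[:, :B, i, :, i:] = input0[..., i:]
    vol_np[:, B:, i, :, i:] = input1[..., : E - i]

# Use the precomputed array as the variable's initial value.
vol = tf.Variable(vol_np, trainable=False)
```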
