How to explicitly broadcast a tensor to match another's shape in tensorflow?

爱一瞬间的悲伤 2021-02-13 09:17

I have three tensors, A, B and C, in TensorFlow. A and B both have shape (m, n, r), and C is a binary tensor of shape …

4 answers
  •  心在旅途
    2021-02-13 09:44

    Here's a dirty hack:

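    import tensorflow as tf
    
    def broadcast(tensor, shape):
        # Adding zeros of the target shape forces implicit broadcasting.
        # Booleans can't be added, so OR them with an all-False tensor instead.
        if tensor.dtype == tf.bool:
            return tf.logical_or(tensor, tf.zeros(shape, dtype=tf.bool))
        return tensor + tf.zeros(shape, dtype=tensor.dtype)
    
    A = tf.random.normal([20, 100, 10])
    B = tf.random.normal([20, 100, 10])
    C = tf.random.normal([20, 100, 1]) > 0  # binary (boolean) mask
    
    C = broadcast(C, A.shape)
    D = tf.where(C, A, B)  # tf.select was renamed tf.where in TF 1.0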
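
    In TensorFlow 2.x the zero-addition trick should not be needed. A minimal sketch, assuming TF 2.x with eager execution and reusing the example shapes from the hack above: `tf.broadcast_to` performs the broadcast explicitly, and `tf.where` then picks elements from A where the mask is True and from B elsewhere. The names `C_full`, `D` and `D_implicit` are illustrative only.

```python
import tensorflow as tf

# Example shapes mirroring the answer above: m=20, n=100, r=10.
A = tf.random.normal([20, 100, 10])
B = tf.random.normal([20, 100, 10])
# Binary (boolean) mask with a trailing dimension of 1.
C = tf.random.normal([20, 100, 1]) > 0

# Explicit broadcast: repeat C along the last axis to match A's shape.
C_full = tf.broadcast_to(C, tf.shape(A))

# Element-wise select: take A where the mask is True, otherwise B.
D = tf.where(C_full, A, B)

# In TF 2.x, tf.where broadcasts its arguments itself, so this is equivalent.
D_implicit = tf.where(C, A, B)
```

    `tf.broadcast_to` is the explicit option the question asks for; the last line shows that with TF 2.x's `tf.where` the explicit step can usually be skipped.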
    
