Tensorflow: Convolutions with different filter for each sample in the mini-batch

伪装坚强ぢ 2021-01-05 02:31

I would like to apply a 2D convolution in TensorFlow with a filter that depends on the sample in the mini-batch. Any ideas how one could do that, especially if the number of

4 answers
  •  心在旅途
    2021-01-05 02:35

    The accepted answer is slightly wrong in how it treats the output dimensions: they shrink when padding = "VALID", but the answer reshapes them as if padding = "SAME". In the general case the code therefore crashes because of this shape mismatch. Here is the corrected code, with both padding modes handled.

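    import tensorflow as tf  # TF 1.x API; MB, H, W, fh, fw, channels, out_channels and F are assumed to be defined

    inp = tf.placeholder(tf.float32, [MB, H, W, channels_img])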
    
    # F has shape (MB, fh, fw, channels, out_channels)
    # Note: in the question's notation, this requires channels_img == channels
    
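    # Fold the batch dimension of the filters into the input-channel dimension
    F = tf.transpose(F, [1, 2, 0, 3, 4])  # shape (fh, fw, MB, channels, out_channels)
    F = tf.reshape(F, [fh, fw, channels*MB, out_channels])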
    
    inp_r = tf.transpose(inp, [1, 2, 0, 3]) # shape (H, W, MB, channels_img)
    inp_r = tf.reshape(inp_r, [1, H, W, MB*channels_img])
    
    padding = "VALID" #or "SAME"
    out = tf.nn.depthwise_conv2d(
              inp_r,
              filter=F,
              strides=[1, 1, 1, 1],
              padding=padding)  # either padding mode works here
    # out has shape (1, H-fh+1, W-fw+1, MB*channels*out_channels) for "VALID",
    # and (1, H, W, MB*channels*out_channels) for "SAME"
    
    # The reshape target must match the spatial output size of the padding mode
    if padding == "SAME":
        out = tf.reshape(out, [H, W, MB, channels, out_channels])
    elif padding == "VALID":
        out = tf.reshape(out, [H-fh+1, W-fw+1, MB, channels, out_channels])
    out = tf.transpose(out, [2, 0, 1, 3, 4])  # batch dimension back to the front
    out = tf.reduce_sum(out, axis=3)  # sum over the input channels
    
    # out shape is now (MB, H-fh+1, W-fw+1, out_channels) for "VALID",
    # or (MB, H, W, out_channels) for "SAME"
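
To see why folding the batch into the channel dimension gives the right answer, here is a small pure-Python sketch of the same index bookkeeping (no TensorFlow, nested lists, stride 1, "VALID" only). The helper names `conv2d_valid`, `depthwise_valid` and `per_sample_conv` are made up for this illustration and are not part of any library. It assumes the depthwise output channel for input channel k and multiplier m sits at index k * M + m, which is the layout the reshape in the code above relies on.

```python
import random

def conv2d_valid(img, flt):
    """Plain convolution, stride 1, VALID. img: [H][W][C], flt: [fh][fw][C][O]."""
    H, W, C = len(img), len(img[0]), len(img[0][0])
    fh, fw, O = len(flt), len(flt[0]), len(flt[0][0][0])
    return [[[sum(img[y + i][x + j][c] * flt[i][j][c][o]
                  for i in range(fh) for j in range(fw) for c in range(C))
              for o in range(O)]
             for x in range(W - fw + 1)]
            for y in range(H - fh + 1)]

def depthwise_valid(img, flt):
    """Depthwise convolution, stride 1, VALID. img: [H][W][K], flt: [fh][fw][K][M].
    Output channel k * M + m is input channel k filtered with multiplier m."""
    H, W, K = len(img), len(img[0]), len(img[0][0])
    fh, fw, M = len(flt), len(flt[0]), len(flt[0][0][0])
    return [[[sum(img[y + i][x + j][k] * flt[i][j][k][m]
                  for i in range(fh) for j in range(fw))
              for k in range(K) for m in range(M)]
             for x in range(W - fw + 1)]
            for y in range(H - fh + 1)]

def per_sample_conv(inp, F):
    """inp: [MB][H][W][C], F: [MB][fh][fw][C][O] -> [MB][H-fh+1][W-fw+1][O]."""
    MB, H, W, C = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    fh, fw, O = len(F[0]), len(F[0][0]), len(F[0][0][0][0])
    # Fold the batch into the channels: merged channel index is b * C + c.
    folded_inp = [[[inp[b][y][x][c] for b in range(MB) for c in range(C)]
                   for x in range(W)] for y in range(H)]
    folded_F = [[[F[b][i][j][c] for b in range(MB) for c in range(C)]
                 for j in range(fw)] for i in range(fh)]
    dw = depthwise_valid(folded_inp, folded_F)
    # Unfold again and sum over the input channels c.
    return [[[[sum(dw[y][x][(b * C + c) * O + o] for c in range(C))
               for o in range(O)]
              for x in range(W - fw + 1)]
             for y in range(H - fh + 1)]
            for b in range(MB)]

def rnd():
    return random.randint(-3, 3)  # integers keep the comparison exact

random.seed(0)
MB, H, W, C, O, fh, fw = 2, 5, 4, 3, 2, 3, 2
inp = [[[[rnd() for _ in range(C)] for _ in range(W)]
        for _ in range(H)] for _ in range(MB)]
F = [[[[[rnd() for _ in range(O)] for _ in range(C)] for _ in range(fw)]
      for _ in range(fh)] for _ in range(MB)]

# Reference: convolve each sample with its own filter, one at a time.
expected = [conv2d_valid(inp[b], F[b]) for b in range(MB)]
result = per_sample_conv(inp, F)
print(result == expected)  # True
```

The per-sample loop is the thing the TensorFlow code is meant to reproduce. The folded version only agrees with it because both the input and the filters use the same merged index b * C + c, so each sample's channels meet only that sample's filters.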
    
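If you use strides other than 1, the hard-coded H-fh+1 and W-fw+1 no longer hold. A small helper along these lines can compute the reshape target instead. This is only a sketch: `conv_output_hw` and `reshape_target` are hypothetical names, and the formulas are the usual TensorFlow padding rules (ceil(H / stride) for "SAME", ceil((H - fh + 1) / stride) for "VALID").

```python
import math

def conv_output_hw(H, W, fh, fw, padding, stride=1):
    """Spatial output size of a 2D convolution under TensorFlow's padding rules."""
    if padding == "SAME":
        return math.ceil(H / stride), math.ceil(W / stride)
    if padding == "VALID":
        return math.ceil((H - fh + 1) / stride), math.ceil((W - fw + 1) / stride)
    raise ValueError(f"unknown padding: {padding!r}")

def reshape_target(MB, H, W, fh, fw, channels, out_channels, padding, stride=1):
    """Shape to pass to tf.reshape right after the depthwise convolution."""
    out_h, out_w = conv_output_hw(H, W, fh, fw, padding, stride)
    return [out_h, out_w, MB, channels, out_channels]

print(reshape_target(4, 28, 28, 3, 3, 2, 5, "VALID"))  # [26, 26, 4, 2, 5]
print(reshape_target(4, 28, 28, 3, 3, 2, 5, "SAME"))   # [28, 28, 4, 2, 5]
```

With stride 1 this reduces to the two branches in the code above, so it can replace the if/elif on padding without changing the result.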
