TensorFlow: Convolutions with a different filter for each sample in the mini-batch

伪装坚强ぢ 2021-01-05 02:31

I would like to have a 2D convolution in TensorFlow with a filter that depends on the sample in the mini-batch. Any ideas how one could do that, especially if the number of

4 answers
  • 2021-01-05 02:35

    The accepted answer is slightly wrong in how it treats the dimensions: with padding = "VALID" the spatial dimensions shrink, but it treats them as if padding = "SAME". Hence, in the general case, the code crashes because of this shape mismatch. Below is his corrected code, with both cases handled correctly.

    import tensorflow as tf  # TF1.x graph-mode API (tf.compat.v1 under TF2)

    inp = tf.placeholder(tf.float32, [MB, H, W, channels_img])
    
    # F has shape (MB, fh, fw, channels, out_channels)
    # REM: with the notation in the question, we need: channels_img==channels
    
    F = tf.transpose(F, [1, 2, 0, 3, 4])
    F = tf.reshape(F, [fh, fw, channels*MB, out_channels])
    
    inp_r = tf.transpose(inp, [1, 2, 0, 3]) # shape (H, W, MB, channels_img)
    inp_r = tf.reshape(inp_r, [1, H, W, MB*channels_img])
    
    padding = "VALID" #or "SAME"
    out = tf.nn.depthwise_conv2d(
              inp_r,
              filter=F,
              strides=[1, 1, 1, 1],
              padding=padding) # here no requirement about padding being 'VALID', use whatever you want. 
    # Now out shape is (1, H-fh+1, W-fw+1, MB*channels*out_channels) for "VALID",
    # or (1, H, W, MB*channels*out_channels) for "SAME"
    
    if padding == "SAME":
        out = tf.reshape(out, [H, W, MB, channels, out_channels])
    if padding == "VALID":
        out = tf.reshape(out, [H-fh+1, W-fw+1, MB, channels, out_channels])
    out = tf.transpose(out, [2, 0, 1, 3, 4])
    out = tf.reduce_sum(out, axis=3)  # sum over the input channels
    
    # out shape is now (MB, H-fh+1, W-fw+1, out_channels) for "VALID",
    # or (MB, H, W, out_channels) for "SAME"
    
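    A minimal, self-contained sketch of the same depthwise trick, assuming TensorFlow 2.x in eager mode (concrete tensors instead of tf.placeholder) and a batch size known when the function runs. The helper names per_sample_conv2d and per_sample_conv2d_reference are hypothetical. Instead of branching on the padding string, the sketch reads the spatial size from the depthwise output, and it compares the result with a plain per-sample loop over tf.nn.conv2d for both paddings.

```python
import numpy as np
import tensorflow as tf


def per_sample_conv2d(inp, F, padding="VALID"):
    """Convolve sample b of `inp` with its own filter F[b] using one depthwise conv.

    inp: (MB, H, W, C), F: (MB, fh, fw, C, O) -> returns (MB, H_out, W_out, O)
    """
    MB, H, W, C = inp.shape
    _, fh, fw, _, O = F.shape

    # Fold the batch into the channel axis: channel index b*C + c.
    inp_r = tf.reshape(tf.transpose(inp, [1, 2, 0, 3]), [1, H, W, MB * C])
    # Same layout for the filter; O plays the role of the channel multiplier.
    F_r = tf.reshape(tf.transpose(F, [1, 2, 0, 3, 4]), [fh, fw, MB * C, O])

    out = tf.nn.depthwise_conv2d(inp_r, F_r, strides=[1, 1, 1, 1], padding=padding)
    # Take the spatial size from the result, so "SAME" and "VALID" both work.
    H_out, W_out = out.shape[1], out.shape[2]

    # Depthwise output channel (b*C + c)*O + o maps back to axes (MB, C, O).
    out = tf.reshape(out, [H_out, W_out, MB, C, O])
    out = tf.transpose(out, [2, 0, 1, 3, 4])  # (MB, H_out, W_out, C, O)
    return tf.reduce_sum(out, axis=3)         # sum over input channels


def per_sample_conv2d_reference(inp, F, padding="VALID"):
    """Plain per-sample loop with tf.nn.conv2d, used only for comparison."""
    return tf.concat(
        [tf.nn.conv2d(inp[b:b + 1], F[b], strides=[1, 1, 1, 1], padding=padding)
         for b in range(inp.shape[0])],
        axis=0,
    )


rng = np.random.default_rng(0)
inp = tf.constant(rng.standard_normal((3, 7, 6, 2)).astype(np.float32))
F = tf.constant(rng.standard_normal((3, 3, 3, 2, 4)).astype(np.float32))
for padding in ("VALID", "SAME"):
    fast = per_sample_conv2d(inp, F, padding)
    slow = per_sample_conv2d_reference(inp, F, padding)
    print(padding, fast.shape, float(tf.reduce_max(tf.abs(fast - slow))))
```

    The loop is easier to read; the depthwise version runs a single convolution op for the whole batch, at the cost of the transposes and the reduce_sum over input channels.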
  • 2021-01-05 02:39

    I think the proposed trick is actually not right. What happens with tf.nn.conv3d() is that the input gets convolved along the depth (= actual batch) dimension AND the products are then summed, so each output mixes several samples. With padding='SAME' the number of outputs along that dimension then happens to equal the batch size, so one gets fooled!
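    A small check of this claim, assuming TensorFlow 2.x eager mode and arbitrary illustrative sizes: feeding the batch as the depth axis of tf.nn.conv3d with a filter depth equal to MB leaves only one depth position under padding='VALID', and that position is the sum of the MB per-sample 2D convolutions. Under padding='SAME' the depth axis has MB positions again, which is what makes the result look per-sample.

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
MB, H, W, C, O, fh, fw = 3, 6, 6, 2, 4, 3, 3  # illustrative sizes
inp = tf.constant(rng.standard_normal((MB, H, W, C)).astype(np.float32))
F = tf.constant(rng.standard_normal((MB, fh, fw, C, O)).astype(np.float32))

# The conv3d trick: the batch becomes the depth axis of one "fake" sample,
# and the filter depth equals MB.
fake = tf.expand_dims(inp, 0)  # (1, MB, H, W, C)
out_valid = tf.nn.conv3d(fake, F, strides=[1, 1, 1, 1, 1], padding="VALID")
out_same = tf.nn.conv3d(fake, F, strides=[1, 1, 1, 1, 1], padding="SAME")

# What we actually want: sample b convolved with F[b] only.
wanted = tf.concat(
    [tf.nn.conv2d(inp[b:b + 1], F[b], strides=[1, 1, 1, 1], padding="VALID")
     for b in range(MB)],
    axis=0,
)  # (MB, H-fh+1, W-fw+1, O)

print(out_valid.shape)  # depth collapses to a single position
print(out_same.shape)   # depth is MB again, which looks per-sample
# The single VALID output is the SUM of the per-sample results:
print(np.allclose(out_valid[0, 0].numpy(),
                  tf.reduce_sum(wanted, axis=0).numpy(), atol=1e-4))
```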

    EDIT: I think a possible way to do a convolution with different filters for the different mini-batch elements involves 'hacking' a depthwise convolution. Assuming batch size MB is known:

    import tensorflow as tf  # TF1.x graph-mode API (tf.compat.v1 under TF2)

    inp = tf.placeholder(tf.float32, [MB, H, W, channels_img])
    
    # F has shape (MB, fh, fw, channels, out_channels)
    # REM: with the notation in the question, we need: channels_img==channels
    
    F = tf.transpose(F, [1, 2, 0, 3, 4])
    F = tf.reshape(F, [fh, fw, channels*MB, out_channels])
    
    inp_r = tf.transpose(inp, [1, 2, 0, 3]) # shape (H, W, MB, channels_img)
    inp_r = tf.reshape(inp_r, [1, H, W, MB*channels_img])
    
    out = tf.nn.depthwise_conv2d(
              inp_r,
              filter=F,
              strides=[1, 1, 1, 1],
              padding='VALID') # here no requirement about padding being 'VALID', use whatever you want. 
    # Now out shape is (1, H, W, MB*channels*out_channels)
    
    out = tf.reshape(out, [H, W, MB, channels, out_channels]) # careful about the order of depthwise conv out_channels!
    out = tf.transpose(out, [2, 0, 1, 3, 4])
    out = tf.reduce_sum(out, axis=3)
    
    # out shape is now (MB, H, W, out_channels)
    

    In case MB is unknown, it should be possible to determine it dynamically using tf.shape() (I think)

  • 2021-01-05 02:41

    The way around it is to add an extra dimension using

    tf.expand_dims(inp, 0)
    

    to create a 'fake' batch size. Then use the

    tf.nn.conv3d()
    

    operation where the filter-depth matches the batch size. This will result in each filter convolving with only one sample in each batch.

    Sadly, you will not solve the variable batch size problem this way, only the convolutions.

  • 2021-01-05 02:47

    You could use tf.map_fn as follows:

    import tensorflow as tf  # TF1.x graph-mode API (tf.compat.v1 under TF2)

    inp = tf.placeholder(tf.float32, [None, h, w, c_in])
    def single_conv(tupl):
        x, kernel = tupl  # x: (1, h, w, c_in), kernel: (fh, fw, c_in, c_out)
        return tf.nn.conv2d(x, kernel, strides=(1, 1, 1, 1), padding='VALID')
    # Assume kernels shape is [tf.shape(inp)[0], fh, fw, c_in, c_out]
    batch_wise_conv = tf.squeeze(tf.map_fn(
        single_conv, (tf.expand_dims(inp, 1), kernels), dtype=tf.float32),
        axis=1
    )
    

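    It is important to specify dtype for map_fn: its output (a single tensor) has a different structure from its input (a tuple of tensors), and without dtype map_fn assumes the two match. Basically, this solution runs one 2D convolution per batch element, batch_dim_size of them in total.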
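    A TensorFlow 2.x version of this answer, as a sketch: it assumes TF 2.3 or later (for fn_output_signature, which replaces the older dtype= argument), and the fixed channel and kernel sizes in the input_signature (c_in=2, fh=fw=3, c_out=4) are illustrative assumptions. Declaring the batch dimension as None in a tf.function shows that the same traced graph serves different batch sizes.

```python
import numpy as np
import tensorflow as tf


# Illustrative sizes: c_in=2, fh=fw=3, c_out=4; batch, h and w are left dynamic.
@tf.function(input_signature=[
    tf.TensorSpec([None, None, None, 2], tf.float32),  # inp: (batch, h, w, c_in)
    tf.TensorSpec([None, 3, 3, 2, 4], tf.float32),     # kernels: (batch, fh, fw, c_in, c_out)
])
def batch_wise_conv(inp, kernels):
    def single_conv(tupl):
        x, kernel = tupl  # x: (1, h, w, c_in), kernel: (fh, fw, c_in, c_out)
        return tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="VALID")

    # The output structure (one tensor) differs from the input (a tuple),
    # so the output signature must be given explicitly.
    out = tf.map_fn(single_conv, (tf.expand_dims(inp, 1), kernels),
                    fn_output_signature=tf.float32)  # (batch, 1, h', w', c_out)
    return tf.squeeze(out, axis=1)


rng = np.random.default_rng(1)
for n in (1, 5):  # the same traced function handles different batch sizes
    inp = tf.constant(rng.standard_normal((n, 7, 6, 2)).astype(np.float32))
    kernels = tf.constant(rng.standard_normal((n, 3, 3, 2, 4)).astype(np.float32))
    print(n, batch_wise_conv(inp, kernels).shape)
```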
