Reconstructing an image after using extract_image_patches

一生所求 2020-12-16 14:37

I have an autoencoder that takes an image as an input and produces a new image as an output.

The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network. After the network produces output patches of the same shape, I need to stitch them back together into a single 1024x1024x3 image. How can I reconstruct the image after using tf.extract_image_patches?

7 Answers
  • 2020-12-16 15:12

    Since I also struggled with this, I am posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:

    import tensorflow as tf
    from keras import backend as K
    import numpy as np
    
    def extract_patches(x):
        return tf.extract_image_patches(
            x,
            (1, 3, 3, 1),
            (1, 1, 1, 1),
            (1, 1, 1, 1),
            padding="VALID"
        )
    
    def extract_patches_inverse(x, y):
        _x = tf.zeros_like(x)
        _y = extract_patches(_x)
        grad = tf.gradients(_y, _x)[0]
        # Divide by grad, to "average" together the overlapping patches
        # otherwise they would simply sum up
        return tf.gradients(_y, _x, grad_ys=y)[0] / grad
    
    # Generate 10 fake images, last dimension can be different than 3
    images = np.random.random((10, 28, 28, 3)).astype(np.float32)
    # Extract patches
    patches = extract_patches(images)
    # Reconstruct image
    # Notice that original images are only passed to infer the right shape
    images_reconstructed = extract_patches_inverse(images, patches) 
    
    # Compare with original (evaluating tf.Tensor into a numpy array)
    # Here using Keras session
    images_r = images_reconstructed.eval(session=K.get_session())
    
    print (np.sum(np.square(images - images_r))) 
    # 2.3820458e-11
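
    If you are on TF 2.x, the same gradient trick can be written with tf.GradientTape and tf.image.extract_patches (the 2.x name of the op). A minimal sketch, my adaptation rather than part of the original answer:

    import tensorflow as tf
    import numpy as np

    def extract_patches(x):
        return tf.image.extract_patches(
            x,
            sizes=(1, 3, 3, 1),
            strides=(1, 1, 1, 1),
            rates=(1, 1, 1, 1),
            padding="VALID"
        )

    def extract_patches_inverse(shape, patches):
        _x = tf.zeros(shape)
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(_x)
            _y = extract_patches(_x)
        # Gradient of a plain sum counts how many patches cover each pixel
        grad = tape.gradient(_y, _x)
        # Weighting the gradient by the patch values scatters them back;
        # dividing by the count averages the overlapping regions
        return tape.gradient(_y, _x, output_gradients=patches) / grad

    images = np.random.random((10, 28, 28, 3)).astype(np.float32)
    patches = extract_patches(images)
    images_r = extract_patches_inverse(images.shape, patches).numpy()
    print(np.sum(np.square(images - images_r)))  # ~1e-11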
    
  • 2020-12-16 15:20

    See Update #2 below for your exact task. First, a small example (TF 1.0):

    Consider an image of size (4,4,1), converted to patches of size (4,2,2,1) and reconstructed back into the image.

    import tensorflow as tf
    image = tf.constant([[[1],  [2],  [3],  [4]],
                         [[5],  [6],  [7],  [8]],
                         [[9],  [10], [11], [12]],
                         [[13], [14], [15], [16]]])
    
    patch_size = [1,2,2,1]
    patches = tf.extract_image_patches([image],
        patch_size, patch_size, [1, 1, 1, 1], 'VALID')
    patches = tf.reshape(patches, [4, 2, 2, 1])
    reconstructed = tf.reshape(patches, [1, 4, 4, 1])
    rec_new = tf.space_to_depth(reconstructed,2)
    rec_new = tf.reshape(rec_new,[4,4,1])
    
    sess = tf.Session()
    I,P,R_n = sess.run([image,patches,rec_new])
    print(I)
    print(I.shape)
    print(P.shape)
    print(R_n)
    print(R_n.shape)
    

    Output:

    [[[ 1][ 2][ 3][ 4]]
      [[ 5][ 6][ 7][ 8]]
      [[ 9][10][11][12]]
      [[13][14][15][16]]]
    (4, 4, 1)
    (4, 2, 2, 1)
    [[[ 1][ 2][ 3][ 4]]
      [[ 5][ 6][ 7][ 8]]
      [[ 9][10][11][12]]
      [[13][14][15][16]]]
    (4, 4, 1)
    

    Update - for 3 channels

    This works only when p = sqrt(h) (here, 32 = sqrt(1024)):

    import tensorflow as tf
    import numpy as np
    c = 3
    h = 1024
    p = 32
    
    image = tf.random_normal([h,h,c])
    patch_size = [1,p,p,1]
    patches = tf.extract_image_patches([image],
       patch_size, patch_size, [1, 1, 1, 1], 'VALID')
    patches = tf.reshape(patches, [h, p, p, c])
    reconstructed = tf.reshape(patches, [1, h, h, c])
    rec_new = tf.space_to_depth(reconstructed,p)
    rec_new = tf.reshape(rec_new,[h,h,c])
    
    sess = tf.Session()
    I,P,R_n = sess.run([image,patches,rec_new])
    print(I.shape)
    print(P.shape)
    print(R_n.shape)
    err = np.sum((R_n-I)**2)
    print(err)
    

    Output:

    (1024, 1024, 3)
    (1024, 32, 32, 3)
    (1024, 1024, 3)
    0.0
    

    Update 2

    Reconstructing directly from the output of extract_image_patches seems difficult. This version uses other functions (space_to_batch_nd and batch_to_space_nd) to extract the patches and reverses the process to reconstruct, which is easier.

    import tensorflow as tf
    import numpy as np
    c = 3
    h = 1024
    p = 128
    
    
    image = tf.random_normal([1,h,h,c])
    
    # Image to Patches Conversion
    pad = [[0,0],[0,0]]
    patches = tf.space_to_batch_nd(image,[p,p],pad)
    patches = tf.split(patches,p*p,0)
    patches = tf.stack(patches,3)
    patches = tf.reshape(patches,[(h/p)**2,p,p,c])
    
    # Do processing on patches
    # Using patches here to reconstruct
    patches_proc = tf.reshape(patches,[1,h/p,h/p,p*p,c])
    patches_proc = tf.split(patches_proc,p*p,3)
    patches_proc = tf.stack(patches_proc,axis=0)
    patches_proc = tf.reshape(patches_proc,[p*p,h/p,h/p,c])
    
    reconstructed = tf.batch_to_space_nd(patches_proc,[p, p],pad)
    
    sess = tf.Session()
    I,P,R_n = sess.run([image,patches,reconstructed])
    print(I.shape)
    print(P.shape)
    print(R_n.shape)
    err = np.sum((R_n-I)**2)
    print(err)
    

    Output:

    (1, 1024, 1024, 3)
    (64, 128, 128, 3)
    (1, 1024, 1024, 3)
    0.0
    

    You can see other cool tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops
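
    On TF 2.x, Update 2 ports almost directly: tf.space_to_batch and tf.batch_to_space replace the *_nd names, and eager execution drops the session. A minimal sketch mirroring the code above (my adaptation, not part of the original answer):

    import tensorflow as tf

    c, h, p = 3, 1024, 128
    image = tf.random.normal([1, h, h, c])

    # Image to patches: every p x p block becomes one batch element
    pad = [[0, 0], [0, 0]]
    patches = tf.space_to_batch(image, [p, p], pad)
    patches = tf.split(patches, p * p, 0)
    patches = tf.stack(patches, 3)
    patches = tf.reshape(patches, [(h // p) ** 2, p, p, c])

    # Reverse the process to reconstruct
    patches_proc = tf.reshape(patches, [1, h // p, h // p, p * p, c])
    patches_proc = tf.split(patches_proc, p * p, 3)
    patches_proc = tf.stack(patches_proc, axis=0)
    patches_proc = tf.reshape(patches_proc, [p * p, h // p, h // p, c])
    reconstructed = tf.batch_to_space(patches_proc, [p, p], pad)

    print(float(tf.reduce_sum((reconstructed - image) ** 2)))  # 0.0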

  • 2020-12-16 15:20
    import tensorflow as tf

    _, n_row, n_col, n_channel = image.shape
    n_patch = n_row * n_col // (patch_size ** 2)  # assume square patches

    patches = tf.image.extract_patches(image, sizes=[1, patch_size, patch_size, 1],
                                       strides=[1, patch_size, patch_size, 1],
                                       rates=[1, 1, 1, 1], padding='VALID')
    patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])

    # Stitch each row of patches together, then stack the rows
    rows = tf.split(patches, n_col // patch_size, axis=0)
    rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]

    reconstructed = tf.concat(rows, axis=0)
    

    I don't know if this is an efficient implementation, but it works!
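
    For reference, the snippet assumes image and patch_size are already defined. A quick setup to try it (TF 2.x, with values matching the question; the random image is only a stand-in):

    import tensorflow as tf

    patch_size = 32
    image = tf.random.normal([1, 1024, 1024, 3])  # stand-in for a real image

    With that setup, tf.reduce_sum((reconstructed - image[0]) ** 2) evaluates to 0.0.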

  • 2020-12-16 15:24

    TF 2.0 users can use tf.nn.space_to_depth and tf.nn.depth_to_space if you aren't doing overlapping blocks; a sketch follows.
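
    A minimal sketch of that route (my adaptation, with shapes matching the question; it only handles non-overlapping blocks):

    import tensorflow as tf

    b = 32                                       # block (patch) size
    image = tf.random.normal([1, 1024, 1024, 3])

    # Fold every non-overlapping b x b block into the channel axis
    folded = tf.nn.space_to_depth(image, b)      # (1, 32, 32, b*b*3)
    patches = tf.reshape(folded, [-1, b, b, 3])  # (1024, 32, 32, 3)

    # Reverse: restore the folded layout, then unfold channels back to space
    folded_back = tf.reshape(patches, [1, 1024 // b, 1024 // b, b * b * 3])
    reconstructed = tf.nn.depth_to_space(folded_back, b)  # (1, 1024, 1024, 3)

    print(float(tf.reduce_sum((reconstructed - image) ** 2)))  # 0.0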

  • 2020-12-16 15:26

    To specifically address the initial question, 'Reconstructing an image after using extract_image_patches', I propose using tf.scatter_nd() to build a stratified image. This works even when the extracted patches overlap or the image is under-sampled. Here is my proposed solution.

    import cv2
    import numpy as np
    import tensorflow as tf
    
    # Function to extract patches using 'extract_image_patches'
    def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
    
        with tf.variable_scope('im2_patches'):
            patches = tf.image.extract_image_patches(
                images=raw_input,
                ksizes=[1, _patch_size[0], _patch_size[1], 1],
                strides=[1, _stride, _stride, 1],
                rates=[1, 1, 1, 1],
                padding='SAME'
            )
    
            h = tf.shape(patches)[1]
            w = tf.shape(patches)[2]
            patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
        return patches, (h, w)
    
    
    # Function to reconstruct image
    def patches_to_img(update, _block_shape, _stride=100):
        with tf.variable_scope('patches2im'):
            _h = _block_shape[0]
            _w = _block_shape[1]
    
            bs = tf.shape(update)[0]  # batch size
            np = tf.shape(update)[1]  # number of patches
            ps_h = tf.shape(update)[2]  # patch height
            ps_w = tf.shape(update)[3]  # patch width
            col_ch = tf.shape(update)[4]  # Colour channel count
    
            wout = (_w - 1) * _stride + ps_w  # Recalculate output shape of "extract_image_patches" including padded pixels
            hout = (_h - 1) * _stride + ps_h  # Recalculate output shape of "extract_image_patches" including padded pixels
    
            x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
            x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
            y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
            xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                         tf.range(0, (hout - ps_h) + 1, _stride))
    
            bb = tf.zeros((1, np, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))  #  batch indices
            yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))  # y indices
            xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))  # x indices
            cc = tf.zeros((bs, np, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))  # color indices
            dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(np), (1, -1, 1, 1, 1, 1))  # shift indices
    
            idx = tf.concat([bb, yy, xx, cc, dd], -1)
    
            stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, np))
            stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))
    
            stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, np))
            stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))
    
            with tf.variable_scope("consolidate"):
                sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
                stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
                reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)
    
            return reconstructed_img, stratified_img
    
    
    
    if __name__ == "__main__":
    
        # load initial image
        image_org = cv2.imread('orig_img.jpg')
        # Add batch dimension
        image = np.expand_dims(image_org, axis=0)
    
        # set parameters
        patch_size = (228, 228)
        stride = 200
    
        input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")
    
        # Extract patches using "extract_image_patches()"
        extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
        # block_shape is the number of patches extracted in the x and in the y dimension
        # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)
    
        reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct Image
    
    
        with tf.Session() as sess:
            ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img], feed_dict={input_img: image})
            # print(bs)
        si = si.astype(np.int32)
    
        # Show reconstructed image
        cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
        cv2.waitKey(0)
    
        # Show stratified images
        for i in range(si.shape[1]):
    
            im_1 = si[0, i, :, :, :]
            cv2.imshow('sd', im_1.astype(np.float32) / 255)
            cv2.waitKey(0)
    

    The above solution should work for batched images of arbitrary color channel dimensions.
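
    To see the sum-then-divide averaging in isolation, here is a toy 1-D sketch (my illustration, not part of the pipeline above): two overlapping patches are scattered into a length-5 signal and the overlap is averaged.

    import tensorflow as tf

    # Two length-3 patches that overlap at position 2 of a length-5 signal
    patches = tf.constant([[1., 2., 3.], [30., 40., 50.]])
    starts = tf.constant([0, 2])

    idx = tf.reshape(starts[:, None] + tf.range(3), (-1, 1))  # target positions
    vals = tf.reshape(patches, [-1])

    summed = tf.scatter_nd(idx, vals, [5])                # overlaps sum up
    counts = tf.scatter_nd(idx, tf.ones_like(vals), [5])  # coverage per pixel
    print((summed / counts).numpy())  # [ 1.   2.  16.5 40.  50. ]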

  • 2020-12-16 15:30

    This code works for your specific case, and more generally whenever the image is square, the kernel is square, and the image size is divisible by the kernel size.

    I did not test it for other cases.

    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    
    size = 1024
    k_size = 32
    axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))
    
    # Define a placeholder for image (or load it directly if you prefer) 
    img = tf.placeholder(tf.int32, shape=(1, size, size, 3))
    
    # Extract patches
    patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1], 
                                             strides=[1, k_size, k_size, 1], 
                                             rates=[1, 1, 1, 1], padding='VALID')
    
    # Reconstruct the image back from the patches
    # First separate out the channel dimension
    reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3)) 
    # Transpose the axes (I got this axes tuple for transpose via experimentation)
    reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
    # Reshape back
    reconstruct = tf.reshape(reconstruct, (size, size, 3))
    
    # Placeholder input; replace with a real image of shape (size, size, 3)
    im_arr = np.random.randint(0, 256, (size, size, 3))
    
    # Run the operations
    with tf.Session() as sess:
        ps, r = sess.run([patches, reconstruct], feed_dict={img:[im_arr]})
    
    # Plot the reconstructed image to verify
    plt.imshow(r)
    plt.show()
    