I have an autoencoder that takes an image as an input and produces a new image as an output.
The input image (1x1024x1024x3) is split into patches (1024x32x32x3) bef
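import tensorflow as tf
# image has shape (1, n_row, n_col, n_channel)
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)  # assumes square patches
# Non-overlapping patches: the stride equals the patch size
patches = tf.image.extract_patches(image, sizes=[1, patch_size, patch_size, 1], strides=[1, patch_size, patch_size, 1], rates=[1, 1, 1, 1], padding='VALID')
# (1, n_row/patch_size, n_col/patch_size, patch_size*patch_size*n_channel) -> (n_patch, patch_size, patch_size, n_channel)
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])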
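# One group per row of patches; each group holds n_col // patch_size patches
rows = tf.split(patches, n_row // patch_size, axis=0)
# Join the patches of each group side by side along the width axis
rows = [tf.concat(tf.unstack(row), axis=1) for row in rows]
# Stack the row strips vertically: result has shape (n_row, n_col, n_channel)
reconstructed = tf.concat(rows, axis=0)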
I don't know if this is an efficient implementation but it works!
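On the efficiency point, one alternative that may be cheaper is to replace the split/unstack/concat chain with a single reshape and transpose. Below is a minimal sketch of that idea in NumPy rather than TensorFlow, assuming a single image without the batch dimension and image sides that are exact multiples of the patch size. The helper names split_into_patches and merge_patches are hypothetical; tf.reshape and tf.transpose should allow the same approach, but I have not benchmarked either version.

```python
import numpy as np

def split_into_patches(image, patch_size):
    """Split an (H, W, C) array into (n_patch, p, p, C) patches, row-major order."""
    n_row, n_col, n_channel = image.shape
    grid_h, grid_w = n_row // patch_size, n_col // patch_size
    # (grid_h, p, grid_w, p, C): index [a, i, b, j, c] is pixel (a*p + i, b*p + j)
    grid = image.reshape(grid_h, patch_size, grid_w, patch_size, n_channel)
    # Bring the two grid axes to the front, then flatten them into one patch axis
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(grid_h * grid_w, patch_size, patch_size, n_channel)

def merge_patches(patches, n_row, n_col):
    """Inverse of split_into_patches: rebuild the (n_row, n_col, C) image."""
    _, patch_size, _, n_channel = patches.shape
    grid_h, grid_w = n_row // patch_size, n_col // patch_size
    grid = patches.reshape(grid_h, grid_w, patch_size, patch_size, n_channel)
    # Interleave grid and pixel axes again: (grid_h, p, grid_w, p, C)
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(n_row, n_col, n_channel)

# Non-square test image so that a row/column mix-up would be visible
image = np.arange(64 * 96 * 3).reshape(64, 96, 3)
patches = split_into_patches(image, 32)
restored = merge_patches(patches, 64, 96)
print(patches.shape)                    # (6, 32, 32, 3)
print(np.array_equal(restored, image))  # True
```

The test image is deliberately non-square: with a square image, swapping the row and column counts goes unnoticed.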