I have an autoencoder that takes an image as input and produces a new image as output.
The input image (1x1024x1024x3) is split into patches (1024x32x32x3) before being fed to the network. How can I merge the output patches back into a single image?
Since I also struggled with this, I'm posting a solution that might be useful to others. The trick is to realize that the inverse of tf.extract_image_patches
is its gradient, as suggested here. Since the gradient of this op is implemented in TensorFlow, it is easy to build the reconstruction function:
import tensorflow as tf
from keras import backend as K
import numpy as np

def extract_patches(x):
    return tf.extract_image_patches(
        x,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID"
    )

def extract_patches_inverse(x, y):
    _x = tf.zeros_like(x)
    _y = extract_patches(_x)
    grad = tf.gradients(_y, _x)[0]
    # Divide by grad, to "average" together the overlapping patches
    # otherwise they would simply sum up
    return tf.gradients(_y, _x, grad_ys=y)[0] / grad

# Generate 10 fake images, last dimension can be different than 3
images = np.random.random((10, 28, 28, 3)).astype(np.float32)

# Extract patches
patches = extract_patches(images)

# Reconstruct image
# Notice that original images are only passed to infer the right shape
images_reconstructed = extract_patches_inverse(images, patches)

# Compare with original (evaluating tf.Tensor into a numpy array)
# Here using Keras session
images_r = images_reconstructed.eval(session=K.get_session())

print(np.sum(np.square(images - images_r)))
# 2.3820458e-11
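The same gradient trick carries over to TF 2.x, where tf.extract_image_patches became tf.image.extract_patches and tf.gradients is replaced by tf.GradientTape in eager mode. A minimal sketch, assuming TF 2.x; it is not the original answer's code, just the same idea restated:

import numpy as np
import tensorflow as tf

def extract_patches(x):
    return tf.image.extract_patches(
        x,
        sizes=[1, 3, 3, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )

def extract_patches_inverse(shape, patches):
    _x = tf.zeros(shape)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(_x)
        _y = extract_patches(_x)
    # Gradient against an all-ones upstream signal counts how many
    # patches each pixel belongs to
    grad = tape.gradient(_y, _x, output_gradients=tf.ones_like(_y))
    # Scatter the patch values back onto the image grid, then average the overlaps
    return tape.gradient(_y, _x, output_gradients=patches) / grad

images = np.random.random((10, 28, 28, 3)).astype(np.float32)
patches = extract_patches(images)
images_r = extract_patches_inverse(images.shape, patches)
print(np.sum(np.square(images - images_r.numpy())))  # ~1e-11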
Update #2 - one small example for your task (TF 1.0):
Consider an image of size (4,4,1), converted to patches of size (4,2,2,1) and reconstructed back into the image.
import tensorflow as tf

image = tf.constant([[[1],  [2],  [3],  [4]],
                     [[5],  [6],  [7],  [8]],
                     [[9],  [10], [11], [12]],
                     [[13], [14], [15], [16]]])

patch_size = [1, 2, 2, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [4, 2, 2, 1])
reconstructed = tf.reshape(patches, [1, 4, 4, 1])
rec_new = tf.space_to_depth(reconstructed, 2)
rec_new = tf.reshape(rec_new, [4, 4, 1])
sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I)
print(I.shape)
print(P.shape)
print(R_n)
print(R_n.shape)
Output:
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
(4, 2, 2, 1)
[[[ 1][ 2][ 3][ 4]]
[[ 5][ 6][ 7][ 8]]
[[ 9][10][11][12]]
[[13][14][15][16]]]
(4, 4, 1)
The following works only for p = sqrt(h): the reshape to [h, p, p, c] holds (h/p)**2 patches, and (h/p)**2 equals h exactly when p**2 = h.
import tensorflow as tf
import numpy as np

c = 3
h = 1024
p = 32

image = tf.random_normal([h, h, c])
patch_size = [1, p, p, 1]
patches = tf.extract_image_patches([image],
                                   patch_size, patch_size, [1, 1, 1, 1], 'VALID')
patches = tf.reshape(patches, [h, p, p, c])
reconstructed = tf.reshape(patches, [1, h, h, c])
rec_new = tf.space_to_depth(reconstructed, p)
rec_new = tf.reshape(rec_new, [h, h, c])
sess = tf.Session()
I,P,R_n = sess.run([image,patches,rec_new])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)
Output:
(1024, 1024, 3)
(1024, 32, 32, 3)
(1024, 1024, 3)
0.0
Reconstructing from the output of extract_image_patches seems difficult. It is easier to use other functions to extract patches and then reverse the process to reconstruct:
import tensorflow as tf
import numpy as np

c = 3
h = 1024
p = 128

image = tf.random_normal([1, h, h, c])

# Image to Patches Conversion
pad = [[0, 0], [0, 0]]
patches = tf.space_to_batch_nd(image, [p, p], pad)
patches = tf.split(patches, p * p, 0)
patches = tf.stack(patches, 3)
patches = tf.reshape(patches, [(h // p) ** 2, p, p, c])  # integer division keeps the shape an int under Python 3

# Do processing on patches
# Using patches here to reconstruct
patches_proc = tf.reshape(patches, [1, h // p, h // p, p * p, c])
patches_proc = tf.split(patches_proc, p * p, 3)
patches_proc = tf.stack(patches_proc, axis=0)
patches_proc = tf.reshape(patches_proc, [p * p, h // p, h // p, c])
reconstructed = tf.batch_to_space_nd(patches_proc, [p, p], pad)
sess = tf.Session()
I,P,R_n = sess.run([image,patches,reconstructed])
print(I.shape)
print(P.shape)
print(R_n.shape)
err = np.sum((R_n-I)**2)
print(err)
Output:
(1, 1024, 1024, 3)
(64, 128, 128, 3)
(1, 1024, 1024, 3)
0.0
You can find other useful tensor transformation functions here: https://www.tensorflow.org/api_guides/python/array_ops
import tensorflow as tf

# image: tensor of shape (1, n_row, n_col, n_channel); patch_size: side length of a square patch
_, n_row, n_col, n_channel = image.shape
n_patch = n_row * n_col // (patch_size ** 2)  # assume square patch
patches = tf.image.extract_patches(image,
                                   sizes=[1, patch_size, patch_size, 1],
                                   strides=[1, patch_size, patch_size, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
rows = tf.split(patches, n_row // patch_size, axis=0)     # one group per row of patches
rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]   # stitch each row horizontally
reconstructed = tf.concat(rows, axis=0)                   # stack the rows vertically
I don't know if this is an efficient implementation but it works!
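For reference, here is the same logic wrapped into a helper with a quick self-check. This is a hedged TF 2.x sketch, and reconstruct_from_patches is a name I made up:

import tensorflow as tf

def reconstruct_from_patches(image, patch_size):
    """Extract non-overlapping square patches and stitch them back together."""
    _, n_row, n_col, n_channel = image.shape
    n_patch = n_row * n_col // (patch_size ** 2)
    patches = tf.image.extract_patches(image,
                                       sizes=[1, patch_size, patch_size, 1],
                                       strides=[1, patch_size, patch_size, 1],
                                       rates=[1, 1, 1, 1], padding='VALID')
    patches = tf.reshape(patches, [n_patch, patch_size, patch_size, n_channel])
    rows = tf.split(patches, n_row // patch_size, axis=0)
    rows = [tf.concat(tf.unstack(r), axis=1) for r in rows]
    return tf.concat(rows, axis=0)

image = tf.random.normal([1, 1024, 1024, 3])
reconstructed = reconstruct_from_patches(image, 32)
print(tf.reduce_sum(tf.square(reconstructed - image[0])).numpy())  # 0.0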
TF 2.0 users can use tf.nn.space_to_depth and tf.nn.depth_to_space if the blocks don't overlap.
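A minimal sketch of that idea, assuming TF 2.x and non-overlapping 32x32 blocks on a 1024x1024 RGB image:

import tensorflow as tf

p, h, c = 32, 1024, 3
image = tf.random.normal([1, h, h, c])

# Each p x p block becomes one output pixel with p*p*c channels,
# and the depth vector holds the block contents in row-major order
patches = tf.reshape(tf.nn.space_to_depth(image, p), [-1, p, p, c])  # (1024, 32, 32, 3)

# Inverse: fold the patch contents back into the depth axis, then redistribute spatially
rec = tf.nn.depth_to_space(tf.reshape(patches, [1, h // p, h // p, p * p * c]), p)
print(tf.reduce_sum(tf.square(rec - image)).numpy())  # 0.0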
To specifically address the initial question, which is 'Reconstructing an image after using extract_image_patches', I propose using tf.scatter_nd()
and building a stratified image. This works even when the extracted patches overlap or the image is under-sampled. Here is my proposed solution.
import cv2
import numpy as np
import tensorflow as tf
# Function to extract patches using 'extract_image_patches'
def img_to_patches(raw_input, _patch_size=(128, 128), _stride=100):
    with tf.variable_scope('im2_patches'):
        patches = tf.image.extract_image_patches(
            images=raw_input,
            ksizes=[1, _patch_size[0], _patch_size[1], 1],
            strides=[1, _stride, _stride, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )
        h = tf.shape(patches)[1]
        w = tf.shape(patches)[2]
        patches = tf.reshape(patches, (patches.shape[0], -1, _patch_size[0], _patch_size[1], 3))
    return patches, (h, w)
# Function to reconstruct the image
def patches_to_img(update, _block_shape, _stride=100):
    with tf.variable_scope('patches2im'):
        _h = _block_shape[0]
        _w = _block_shape[1]
        bs = tf.shape(update)[0]      # batch size
        n_p = tf.shape(update)[1]     # number of patches (renamed so it doesn't shadow numpy's np)
        ps_h = tf.shape(update)[2]    # patch height
        ps_w = tf.shape(update)[3]    # patch width
        col_ch = tf.shape(update)[4]  # colour channel count
        wout = (_w - 1) * _stride + ps_w  # recalculate output shape of "extract_image_patches", including padded pixels
        hout = (_h - 1) * _stride + ps_h  # recalculate output shape of "extract_image_patches", including padded pixels
        x, y = tf.meshgrid(tf.range(ps_w), tf.range(ps_h))
        x = tf.reshape(x, (1, 1, ps_h, ps_w, 1, 1))
        y = tf.reshape(y, (1, 1, ps_h, ps_w, 1, 1))
        xstart, ystart = tf.meshgrid(tf.range(0, (wout - ps_w) + 1, _stride),
                                     tf.range(0, (hout - ps_h) + 1, _stride))
        bb = tf.zeros((1, n_p, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(bs), (-1, 1, 1, 1, 1, 1))   # batch indices
        yy = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + y + tf.reshape(ystart, (1, -1, 1, 1, 1, 1))            # y indices
        xx = tf.zeros((bs, 1, 1, 1, col_ch, 1), dtype=tf.int32) + x + tf.reshape(xstart, (1, -1, 1, 1, 1, 1))            # x indices
        cc = tf.zeros((bs, n_p, ps_h, ps_w, 1, 1), dtype=tf.int32) + tf.reshape(tf.range(col_ch), (1, 1, 1, 1, -1, 1))   # colour indices
        dd = tf.zeros((bs, 1, ps_h, ps_w, col_ch, 1), dtype=tf.int32) + tf.reshape(tf.range(n_p), (1, -1, 1, 1, 1, 1))   # shift indices
        idx = tf.concat([bb, yy, xx, cc, dd], -1)
        stratified_img = tf.scatter_nd(idx, update, (bs, hout, wout, col_ch, n_p))
        stratified_img = tf.transpose(stratified_img, (0, 4, 1, 2, 3))
        stratified_img_count = tf.scatter_nd(idx, tf.ones_like(update), (bs, hout, wout, col_ch, n_p))
        stratified_img_count = tf.transpose(stratified_img_count, (0, 4, 1, 2, 3))
        with tf.variable_scope("consolidate"):
            sum_stratified_img = tf.reduce_sum(stratified_img, axis=1)
            stratified_img_count = tf.reduce_sum(stratified_img_count, axis=1)
            reconstructed_img = tf.divide(sum_stratified_img, stratified_img_count)
        return reconstructed_img, stratified_img
if __name__ == "__main__":
    # load initial image
    image_org = cv2.imread('orig_img.jpg')
    # Add batch dimension
    image = np.expand_dims(image_org, axis=0)

    # set parameters
    patch_size = (228, 228)
    stride = 200

    input_img = tf.placeholder(dtype=tf.float32, shape=image.shape, name="input_img")

    # Extract patches using "extract_image_patches()"
    extracted_patches, block_shape = img_to_patches(input_img, _patch_size=patch_size, _stride=stride)
    # block_shape is the number of patches extracted in the x and in the y dimension
    # extracted_patches.shape = (1, block_shape[0] * block_shape[1], patch_size[0], patch_size[1], 3)
    reconstructed_img, stratified_img = patches_to_img(extracted_patches, block_shape, stride)  # Reconstruct the image

    with tf.Session() as sess:
        ep, bs, ri, si = sess.run([extracted_patches, block_shape, reconstructed_img, stratified_img],
                                  feed_dict={input_img: image})
    # print(bs)
    si = si.astype(np.int32)

    # Show the reconstructed image
    cv2.imshow('sd', ri[0, :, :, :].astype(np.float32) / 255)
    cv2.waitKey(0)

    # Show the stratified images
    for i in range(si.shape[1]):
        im_1 = si[0, i, :, :, :]
        cv2.imshow('sd', im_1.astype(np.float32) / 255)
        cv2.waitKey(0)
The above solution should work for batched images with arbitrary colour channel dimensions.
This code works for your specific case, as well as for cases where the image is square, the kernel is square, and the image size is divisible by the kernel size.
I did not test it for other cases.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

size = 1024
k_size = 32
axes_1_2_size = int(np.sqrt((size * size) / (k_size * k_size)))

# Define a placeholder for the image (or load it directly if you prefer)
img = tf.placeholder(tf.int32, shape=(1, size, size, 3))

# Extract patches
patches = tf.image.extract_image_patches(img, ksizes=[1, k_size, k_size, 1],
                                         strides=[1, k_size, k_size, 1],
                                         rates=[1, 1, 1, 1], padding='VALID')

# Reconstruct the image back from the patches
# First separate out the channel dimension
reconstruct = tf.reshape(patches, (1, axes_1_2_size, axes_1_2_size, k_size, k_size, 3))
# Transpose the axes (I got this axes tuple for transpose via experimentation)
reconstruct = tf.transpose(reconstruct, (0, 1, 3, 2, 4, 5))
# Reshape back
reconstruct = tf.reshape(reconstruct, (size, size, 3))

im_arr = np.random.randint(0, 256, (size, size, 3))  # placeholder data; substitute an image loaded with shape (size, size, 3)

# Run the operations
with tf.Session() as sess:
    ps, r = sess.run([patches, reconstruct], feed_dict={img: [im_arr]})

# Plot the reconstructed image to verify
plt.imshow(r)
plt.show()
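The same transpose trick ports directly to TF 2.x eager mode. A minimal sketch, assuming tf.image.extract_patches as the TF 2.x replacement for the op above; the random input stands in for a real image:

import numpy as np
import tensorflow as tf

size, k_size = 1024, 32
n = size // k_size  # number of patches along each spatial axis

img = tf.constant(np.random.random((1, size, size, 3)), dtype=tf.float32)
patches = tf.image.extract_patches(img, sizes=[1, k_size, k_size, 1],
                                   strides=[1, k_size, k_size, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')

# Separate the patch contents, swap the patch-row and within-patch-row axes,
# then collapse back to the full image
rec = tf.reshape(patches, (1, n, n, k_size, k_size, 3))
rec = tf.transpose(rec, (0, 1, 3, 2, 4, 5))
rec = tf.reshape(rec, (size, size, 3))
print(tf.reduce_sum(tf.square(rec - img[0])).numpy())  # 0.0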