I have a 3D image with shape D x H x W. I have successfully extracted overlapping
patches of shape pd x ph x pw from it. For each patch, I do some processing …
This does the reverse. However, since your patches overlap, the result is only well-defined if their values agree wherever they overlap:
def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
    """Reassemble a 3-D volume from (possibly overlapping) 3-D patches.

    This is the inverse of a sliding-window extraction whose windows start
    at ``(i * xstep, j * ystep, k * zstep)``. Where patches overlap, the
    patch at the later grid position (C order) overwrites earlier ones, so
    the result is only well-defined if overlapping patches agree.

    Parameters
    ----------
    out_shape : tuple of 3 ints
        Shape ``(D, H, W)`` of the volume to rebuild.
    patches : array_like
        Patches whose last three axes are the patch shape ``(pd, ph, pw)``.
        Either ``(nx, ny, nz, pd, ph, pw)`` or any layout that reshapes to
        it in C order (e.g. ``(nx * ny * nz, pd, ph, pw)``). Along each
        axis, ``n = (size - patch) // step + 1``.
    xstep, ystep, zstep : int
        Window step along axes 0, 1 and 2.

    Returns
    -------
    numpy.ndarray
        Array of shape ``out_shape`` and dtype ``patches.dtype``. Voxels not
        covered by any patch are 0.

    Raises
    ------
    ValueError
        If ``patches`` does not hold exactly one patch per window position.
    """
    patches = np.asarray(patches)
    out = np.zeros(out_shape, patches.dtype)
    patch_shape = tuple(patches.shape[-3:])
    steps = (xstep, ystep, zstep)
    # A window of length p moved by step s fits (n - p) // s + 1 times into
    # an axis of length n; (n - p + 1) // s undercounts whenever s > 1.
    # Clamp at 0 so an oversized patch never produces a negative reshape
    # dimension (which NumPy would treat as "infer this axis").
    grid = tuple(
        max((n - p) // s + 1, 0)
        for n, p, s in zip(out.shape, patch_shape, steps)
    )
    expected = grid + patch_shape
    if patches.size != int(np.prod(expected)):
        raise ValueError(
            f"expected {int(np.prod(grid))} patches of shape {patch_shape} "
            f"for output shape {out.shape} and steps {steps}, "
            f"got array of shape {patches.shape}"
        )
    patches = patches.reshape(expected)
    # Plain slice assignment instead of writing through a self-overlapping
    # as_strided view, for which NumPy documents the result as unpredictable.
    for idx in np.ndindex(*grid):
        region = tuple(
            slice(i * s, i * s + p) for i, s, p in zip(idx, steps, patch_shape)
        )
        out[region] = patches[idx]
    return out
Update: here is a safer version that averages overlapping voxels:
def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12,
                     fill_value=np.nan):
    """Reassemble a 3-D volume from overlapping 3-D patches by averaging.

    Each patch is added into the window starting at
    ``(i * xstep, j * ystep, k * zstep)``. Every voxel is then divided by
    the number of patches that covered it, so overlapping values are
    averaged.

    Parameters
    ----------
    out_shape : tuple of 3 ints
        Shape ``(D, H, W)`` of the volume to rebuild.
    patches : array_like
        Patches whose last three axes are the patch shape ``(pd, ph, pw)``.
        Either ``(nx, ny, nz, pd, ph, pw)`` or any layout that reshapes to
        it in C order (e.g. ``(nx * ny * nz, pd, ph, pw)``). Along each
        axis, ``n = (size - patch) // step + 1``.
    xstep, ystep, zstep : int
        Window step along axes 0, 1 and 2.
    fill_value : scalar, optional
        Value written to voxels that no patch covers (default NaN).

    Returns
    -------
    numpy.ndarray
        Array of shape ``out_shape``. Floating/complex input keeps its
        dtype; integer or bool input yields float64.

    Raises
    ------
    ValueError
        If ``patches`` does not hold exactly one patch per window position.
    """
    patches = np.asarray(patches)
    patch_shape = tuple(patches.shape[-3:])
    steps = (xstep, ystep, zstep)
    # Sum floating input in its own dtype, but integer/bool input in float64.
    # Otherwise overlapping sums can wrap around (e.g. uint8 200 + 200).
    if np.issubdtype(patches.dtype, np.inexact):
        acc_dtype = patches.dtype
    else:
        acc_dtype = np.dtype(np.float64)
    out = np.zeros(out_shape, acc_dtype)
    denom = np.zeros(out_shape, acc_dtype)
    # A window of length p moved by step s fits (n - p) // s + 1 times into
    # an axis of length n; (n - p + 1) // s undercounts whenever s > 1.
    # Clamp at 0 so an oversized patch never produces a negative reshape dim.
    grid = tuple(
        max((n - p) // s + 1, 0)
        for n, p, s in zip(out.shape, patch_shape, steps)
    )
    expected = grid + patch_shape
    if patches.size != int(np.prod(expected)):
        raise ValueError(
            f"expected {int(np.prod(grid))} patches of shape {patch_shape} "
            f"for output shape {out.shape} and steps {steps}, "
            f"got array of shape {patches.shape}"
        )
    patches = patches.reshape(expected)
    # Accumulate one patch at a time. Unlike np.add.at over np.indices, this
    # does not materialise six index arrays the size of all patches.
    for idx in np.ndindex(*grid):
        region = tuple(
            slice(i * s, i * s + p) for i, s, p in zip(idx, steps, patch_shape)
        )
        out[region] += patches[idx]
        denom[region] += 1
    # Divide only where at least one patch contributed. This avoids 0/0
    # RuntimeWarnings; uncovered voxels get fill_value instead.
    result = np.full(out.shape, fill_value, dtype=acc_dtype)
    np.divide(out, denom, out=result, where=denom != 0)
    return result