How to recover 3D image from its patches in Python?

别那么骄傲 2021-01-16 03:36

I have a 3D image with shape D×H×W. I successfully extracted it into overlapping patches of shape pd×ph×pw. For each patch, I do some pro…
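The question does not show how the patches were extracted. As an assumption, the sketch below takes them on a regular grid with NumPy's `np.lib.stride_tricks.sliding_window_view` (NumPy ≥ 1.20). `extract_patches_3D` and its `steps` argument are hypothetical names. The result has shape `(nx, ny, nz, pd, ph, pw)`, which is the layout the answer's code reshapes the patches into.

```python
import numpy as np

def extract_patches_3D(image, patch_shape, steps):
    """Hypothetical helper: overlapping patches taken on a regular grid.

    Returns a read-only view of shape (nx, ny, nz, pd, ph, pw) in which
    patch (i, j, k) starts at voxel (i*steps[0], j*steps[1], k*steps[2]).
    """
    # All possible windows: shape (D-pd+1, H-ph+1, W-pw+1, pd, ph, pw)
    windows = np.lib.stride_tricks.sliding_window_view(image, patch_shape)
    # Keep only every step-th window along each axis
    return windows[::steps[0], ::steps[1], ::steps[2]]

image = np.arange(10 * 12 * 14).reshape(10, 12, 14)
patches = extract_patches_3D(image, (4, 4, 4), (3, 4, 5))
print(patches.shape)  # (3, 3, 3, 4, 4, 4): (size - patch) // step + 1 positions per axis
```

Because the result is a view, no voxel data is copied; call `.copy()` before modifying the patches in place.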

1 Answer
  •  别那么骄傲
    2021-01-16 03:51

    This will do the reverse. However, since your patches overlap, the result is only well-defined if the patches agree on the values they share; otherwise, which patch's value ends up in a shared voxel is not specified.

    import numpy as np

    def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
        # patches: (nx, ny, nz, pd, ph, pw); patch (i, j, k) starts at (i*xstep, j*ystep, k*zstep)
        out = np.zeros(out_shape, patches.dtype)
        patch_shape = patches.shape[-3:]
        steps = (xstep, ystep, zstep)
        # Patch positions per axis: (size - patch) // step + 1
        grid = tuple((s - p) // st + 1 for s, p, st in zip(out.shape, patch_shape, steps))
        # Writable 6D view of `out`: axes 0-2 pick a patch, axes 3-5 a voxel inside it
        patches_6D = np.lib.stride_tricks.as_strided(
            out, shape=grid + patch_shape,
            strides=tuple(o * st for o, st in zip(out.strides, steps)) + out.strides)
        # Shared voxels are written several times; which value survives is unspecified
        patches_6D[...] = patches.reshape(patches_6D.shape)
        return out
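Whether this first version gives a well-defined result depends on the condition stated above, so it can help to check whether overlapping patches actually agree. This check is an addition, not part of the original answer. `patches_agree` is a hypothetical helper, and it assumes `patches` has shape `(nx, ny, nz, pd, ph, pw)` with patch `(i, j, k)` starting at voxel `(i*xstep, j*ystep, k*zstep)`, passed here as a `steps` tuple. It records the smallest and largest value any patch assigns to each voxel; the patches agree exactly when the two coincide on every covered voxel.

```python
import numpy as np

def patches_agree(out_shape, patches, steps):
    """Hypothetical check: do overlapping patches agree on every voxel they share?

    patches has shape (nx, ny, nz, pd, ph, pw); patch (i, j, k) is assumed to
    start at voxel (i*steps[0], j*steps[1], k*steps[2]).
    """
    lo = np.full(out_shape, np.inf)
    hi = np.full(out_shape, -np.inf)
    pd, ph, pw = patches.shape[-3:]
    for i, j, k in np.ndindex(*patches.shape[:3]):
        x, y, z = i * steps[0], j * steps[1], k * steps[2]
        region = (slice(x, x + pd), slice(y, y + ph), slice(z, z + pw))
        # Track the smallest and largest value assigned to each voxel
        lo[region] = np.minimum(lo[region], patches[i, j, k])
        hi[region] = np.maximum(hi[region], patches[i, j, k])
    covered = hi >= lo  # voxels touched by at least one patch
    return bool(np.all(lo[covered] == hi[covered]))

image = np.arange(8 * 8 * 8, dtype=float).reshape(8, 8, 8)
# Patches cut from one image (requires NumPy >= 1.20); copy so they can be edited
patches = np.lib.stride_tricks.sliding_window_view(image, (4, 4, 4))[::2, ::2, ::2].copy()
print(patches_agree(image.shape, patches, (2, 2, 2)))  # → True
patches[0, 0, 0, 2, 2, 2] += 1.0  # voxel (2, 2, 2) is also covered by patch (1, 1, 1)
print(patches_agree(image.shape, patches, (2, 2, 2)))  # → False
```

If your per-patch processing changes values, the check will usually fail, which is the situation the averaging version is meant for.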
    

    Update: here is a safer version that averages overlapping pixels:

    import numpy as np

    def stuff_patches_3D(out_shape, patches, xstep=12, ystep=12, zstep=12):
        # Accumulate in float64 so integer patches (e.g. uint8) cannot overflow
        out = np.zeros(out_shape, np.float64)
        denom = np.zeros(out_shape, np.float64)
        patch_shape = patches.shape[-3:]
        steps = (xstep, ystep, zstep)
        # Patch positions per axis: (size - patch) // step + 1
        grid = tuple((s - p) // st + 1 for s, p, st in zip(out.shape, patch_shape, steps))

        def as_6D(a):
            # Overlapping 6D view of `a`: axes 0-2 pick a patch, axes 3-5 a voxel inside it
            return np.lib.stride_tricks.as_strided(
                a, shape=grid + patch_shape,
                strides=tuple(o * st for o, st in zip(a.strides, steps)) + a.strides)

        patches_6D, denom_6D = as_6D(out), as_6D(denom)
        # np.add.at is unbuffered, so a voxel shared by several patches receives every contribution
        idx = tuple(i.ravel() for i in np.indices(patches_6D.shape))
        np.add.at(patches_6D, idx, patches.ravel())
        np.add.at(denom_6D, idx, 1)
        # Voxels covered by no patch would be 0/0 = NaN; leave them at 0 instead
        return np.divide(out, denom, out=np.zeros_like(out), where=denom > 0)
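If writing through overlapping `as_strided` views feels fragile, the same averaging can be done with an ordinary loop over patch positions. This is a swapped-in technique, not the answer's method: `average_patches_3D` is a hypothetical name, it assumes the same `(nx, ny, nz, pd, ph, pw)` layout with a `steps` tuple, and it uses `np.lib.stride_tricks.sliding_window_view` (NumPy ≥ 1.20) only to build test patches. It is slower when there are very many patches, but each step is plain slice arithmetic that is easy to verify.

```python
import numpy as np

def average_patches_3D(out_shape, patches, steps):
    """Hypothetical loop-based alternative: average overlapping patches into a volume.

    patches has shape (nx, ny, nz, pd, ph, pw); patch (i, j, k) is assumed to
    start at voxel (i*steps[0], j*steps[1], k*steps[2]).
    """
    total = np.zeros(out_shape, np.float64)
    count = np.zeros(out_shape, np.float64)
    pd, ph, pw = patches.shape[-3:]
    for i, j, k in np.ndindex(*patches.shape[:3]):
        x, y, z = i * steps[0], j * steps[1], k * steps[2]
        region = (slice(x, x + pd), slice(y, y + ph), slice(z, z + pw))
        total[region] += patches[i, j, k]  # ordinary slices, so no aliasing issues
        count[region] += 1
    # Voxels no patch touches stay 0 rather than 0/0
    return np.divide(total, count, out=np.zeros_like(total), where=count > 0)

# Round trip: extract overlapping patches that cover the whole volume, then rebuild it
image = np.random.default_rng(0).random((20, 24, 28))
steps = (6, 8, 5)
windows = np.lib.stride_tricks.sliding_window_view(image, (8, 8, 8))
patches = windows[::steps[0], ::steps[1], ::steps[2]]
restored = average_patches_3D(image.shape, patches, steps)
print(np.allclose(restored, image))  # → True
```

The round trip is exact only when the patches tile the whole volume; with larger steps, any voxels that fall in gaps come back as 0.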
    
