Getting started with TensorFlow - Split image into sub-images

Submitted by 一曲冷凌霜 on 2019-12-21 05:43:09

Question


This is my very first time using Convolutional Neural Networks and TensorFlow.

I am trying to implement a convolutional neural network that can extract blood vessels from digital retinal images. I am working with the publicly available DRIVE database (the images are in .tif format).

Since my images are very large, my idea is to split them into sub-images of size 28x28x1 (the "1" is the green channel, the only one I need). To create the training set, I repeatedly crop a random 28x28 patch from the images and train the network on this set.
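A minimal NumPy sketch of that sampling step, assuming the green channel and the label have already been converted to 2-D arrays of the same shape (random_patch and PATCH are hypothetical names, not taken from the code below). Because the upper bound of the random draw is exclusive, the sketch adds 1 so that a window may touch the right and bottom borders:

```python
import numpy as np

PATCH = 28  # hypothetical: patch side length, matching BATCH_WIDTH/BATCH_HEIGHT below


def random_patch(green, label, rng, size=PATCH):
    """Crop the same random size x size window from a 2-D image and its label.

    green, label: 2-D arrays of shape [height, width].
    rng.integers excludes its upper bound, hence the + 1.
    """
    height, width = green.shape
    y = rng.integers(0, height - size + 1)
    x = rng.integers(0, width - size + 1)
    return green[y:y + size, x:x + size], label[y:y + size, x:x + size]


# usage: both crops come from the same window of a toy 40x50 "image"
rng = np.random.default_rng(0)
image = np.arange(40 * 50).reshape(40, 50)
patch, patch_label = random_patch(image, image % 2, rng)
```

Cropping the image and the label with the same (x, y) is what keeps each training input aligned with its ground truth.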

Now I would like to test my trained network on one of the full-size images in the database (that is, apply the network to a complete eye). Since the network is trained on 28x28 sub-images, the idea is to split the eye into n sub-images, pass them through the network, reassemble them, and show the result as in Fig. 1:

[Fig. 1]

I tried functions such as tf.extract_image_patches and tf.train.batch, but I would like to know the right way to do this.

Below is a snippet of my code. The function where I am stuck is split_image(image).

import numpy
import os
import random

from PIL import Image
import tensorflow as tf

BATCH_WIDTH = 28
BATCH_HEIGHT = 28

NUM_TRIALS = 10

class Drive:
    def __init__(self,train):
        self.train = train

class Dataset:
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        self.current_batch = 0

    def next_batch(self):
        batch = self.inputs[self.current_batch], self.labels[self.current_batch]
        self.current_batch = (self.current_batch + 1) % len(self.inputs)
        return batch


# returns True if more than half of the pixels in the patch are darker than black_thresh
def mostlyBlack(image):
    pixels = image.getdata()
    black_thresh = 50
    nblack = 0
    for pixel in pixels:
        if pixel < black_thresh:
            nblack += 1

    return nblack / float(len(pixels)) > 0.5

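# crop the same random window from the image (green channel) and the label (first channel)
def cropImage(image, label):
    width  = image.size[0]
    height = image.size[1]
    # randrange excludes its upper bound, so + 1 lets the window reach the right/bottom edge
    x = random.randrange(0, width - BATCH_WIDTH + 1)
    y = random.randrange(0, height - BATCH_HEIGHT + 1)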
    image = image.crop((x, y, x + BATCH_WIDTH, y + BATCH_HEIGHT)).split()[1]
    label = label.crop((x, y, x + BATCH_WIDTH, y + BATCH_HEIGHT)).split()[0]
    return image, label

def split_image(image):

    ksizes_ = [1, BATCH_WIDTH, BATCH_HEIGHT, 1]
    strides_ = [1, BATCH_WIDTH, BATCH_HEIGHT, 1]

    input = numpy.array(image.split()[1])
    #input = tf.reshape((input), [image.size[0], image.size[1]])

    #input = tf.train.batch([input],batch_size=1)
    split = tf.extract_image_patches(input, padding='VALID', ksizes=ksizes_, strides=strides_, rates=[1,28,28,1], name="asdk")

#creates NUM_TRIALS images from a dataset
def create_dataset(images_path, label_path):
    files = os.listdir(images_path)
    label_files = os.listdir(label_path)

    images = [];
    labels = [];
    t = 0
    while t < NUM_TRIALS:
        index = random.randrange(0, len(files))
        if files[index].endswith(".tif"):
            image_filename = os.path.join(images_path, files[index])
            label_filename = os.path.join(label_path, label_files[index])
            image = Image.open(image_filename)
            label = Image.open(label_filename)
            image, label = cropImage(image, label)
            if not mostlyBlack(image):
                #images.append(tf.convert_to_tensor(numpy.array(image)))
                #labels.append(tf.convert_to_tensor(numpy.array(label)))
                images.append(numpy.array(image))
                labels.append(numpy.array(label))

                t+=1

    image = Image.open(os.path.join(images_path, files[1]))
    split_image(image)

    train = Dataset(images, labels)
    return Drive(train)
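For comparison with the split_image attempt above: tf.extract_image_patches expects a 4-D [batch, rows, cols, depth] input, and its rates argument is a dilation factor, so a non-overlapping tiling would use ksizes and strides of [1, 28, 28, 1] with rates=[1, 1, 1, 1]. The NumPy sketch below (extract_tiles is a hypothetical helper, and sliding_window_view needs NumPy 1.20 or later) should produce roughly the same VALID-padded, row-major set of tiles for a 2-D green-channel array, except that TensorFlow flattens each patch into its last dimension instead of keeping it 28x28:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def extract_tiles(green, size=28):
    """Return the non-overlapping size x size tiles of a 2-D array in row-major order.

    Trailing rows/columns that do not fill a whole tile are dropped (like VALID padding).
    Output shape: [num_tiles, size, size].
    """
    windows = sliding_window_view(green, (size, size))  # every size x size window
    tiles = windows[::size, ::size]                     # keep one window per tile position
    return tiles.reshape(-1, size, size)                # tile k = (k // n_cols, k % n_cols)
```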

Answer 1:


You can use a combination of reshape and transpose calls to cut an image into tiles:

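def split_image(image3, tile_size):
    # image3: [height, width, C]; tile_size: [tile_height, tile_width]
    image_shape = tf.shape(image3)
    # split the width into columns of tile_width: [height, width/tw, tw, C]
    tile_rows = tf.reshape(image3, [image_shape[0], -1, tile_size[1], image_shape[2]])
    # move the tile-column axis to the front: [width/tw, height, tw, C]
    serial_tiles = tf.transpose(tile_rows, [1, 0, 2, 3])
    # split the height into tile_height chunks: [B, th, tw, C]
    return tf.reshape(serial_tiles, [-1, tile_size[0], tile_size[1], image_shape[2]])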

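where image3 is a 3-D tensor of shape [height, width, channels] (e.g. an image), and tile_size is a pair of values [H, W] specifying the height and width of a tile. The output is a tensor with shape [B, H, W, C], where B is the number of tiles, ordered column by column (top to bottom within the first column of tiles, then the next column). In your case, with image holding the green channel as an [height, width, 1] tensor, the call would be: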

tiles = split_image(image, [28, 28])

resulting in a tensor with shape [B, 28, 28, 1]. You can also reassemble the original image from the tiles by performing these operations in reverse:

def unsplit_image(tiles4, image_shape):
    # tiles4: [B, th, tw, C] in the order produced by split_image
    tile_width = tf.shape(tiles4)[2]
    # regroup the tiles into full-height columns: [width/tw, height, tw, C]
    serialized_tiles = tf.reshape(tiles4, [-1, image_shape[0], tile_width, image_shape[2]])
    rowwise_tiles = tf.transpose(serialized_tiles, [1, 0, 2, 3])
    return tf.reshape(rowwise_tiles, [image_shape[0], image_shape[1], image_shape[2]])

Here tiles4 is a 4-D tensor of shape [B, H, W, C] in the order produced by split_image, and image_shape is the shape of the original image. In your case the call could be:

image = unsplit_image(tiles, tf.shape(image))
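To check the tile ordering and the round trip without a TensorFlow session, here is a NumPy sketch that mirrors split_image and unsplit_image step by step (the _np names are hypothetical, and NumPy is swapped in for TensorFlow only so the shapes can be inspected directly). The test uses non-square tiles, so a mix-up between tile height and tile width would show up as a wrong result:

```python
import numpy as np


def split_image_np(image3, tile_size):
    """NumPy mirror of split_image: [H, W, C] -> [B, th, tw, C].

    Assumes H and W are exact multiples of tile_size = (th, tw).
    Tiles come out column by column (top to bottom, then the next column).
    """
    h, w, c = image3.shape
    th, tw = tile_size
    tile_rows = image3.reshape(h, -1, tw, c)        # [H, W/tw, tw, C]
    serial_tiles = tile_rows.transpose(1, 0, 2, 3)  # [W/tw, H, tw, C]
    return serial_tiles.reshape(-1, th, tw, c)      # [B, th, tw, C]


def unsplit_image_np(tiles4, image_shape):
    """Inverse of split_image_np: [B, th, tw, C] -> [H, W, C]."""
    h, w, c = image_shape
    tw = tiles4.shape[2]
    serialized = tiles4.reshape(-1, h, tw, c)       # [W/tw, H, tw, C]
    rowwise = serialized.transpose(1, 0, 2, 3)      # [H, W/tw, tw, C]
    return rowwise.reshape(h, w, c)
```

With a 6x8 image and 3x4 tiles, tile 1 is the bottom-left block and tile 2 the top-right block, which is the column-by-column order described above.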

Note that this only works if the image size is divisible by the tile size. If that's not the case you need to pad your image to the nearest multiple of the tile size:

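def pad_image_to_tile_multiple(image3, tile_size, padding="CONSTANT"):
    # height and width of the image
    imagesize = tf.shape(image3)[0:2]
    # rows/columns needed to reach the next multiple of tile_size
    padding_ = tf.to_int32(tf.ceil(imagesize / tile_size)) * tile_size - imagesize
    # pad only the bottom and right edges; the channel axis is left unchanged
    return tf.pad(image3, [[0, padding_[0]], [0, padding_[1]], [0, 0]], padding)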

You would call it like this:

original_size = tf.shape(image)
image = pad_image_to_tile_multiple(image, [28, 28])

After reassembling the image from the tiles, remove the padding by slicing it back to original_size, the shape of the image before padding:

image = image[0:original_size[0], 0:original_size[1], :]
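Putting the pieces together for the whole-eye case in the question, here is a NumPy sketch of pad, tile, predict, reassemble, crop. It is an illustration under stated assumptions, not the answer's exact code: it swaps in a 5-D reshape/transpose that yields row-major tiles, pad_to_tile_multiple and apply_tilewise are hypothetical names, and fn is a hypothetical stand-in for running the trained network on a [B, 28, 28, 1] batch and returning an output of the same shape:

```python
import numpy as np


def pad_to_tile_multiple(image3, tile_size):
    """Zero-pad the bottom and right of an [H, W, C] array up to multiples of tile_size."""
    h, w = image3.shape[:2]
    th, tw = tile_size
    pad_h = -h % th  # equals ceil(h / th) * th - h for positive integers
    pad_w = -w % tw
    return np.pad(image3, [(0, pad_h), (0, pad_w), (0, 0)], mode="constant")


def apply_tilewise(image3, tile_size, fn):
    """Pad, cut into tiles, run fn on the [B, th, tw, C] batch, reassemble, crop."""
    h, w, c = image3.shape
    th, tw = tile_size
    padded = pad_to_tile_multiple(image3, tile_size)
    ph, pw = padded.shape[:2]
    # [pH/th, th, pW/tw, tw, C] -> [pH/th, pW/tw, th, tw, C]: row-major tiles
    tiles = padded.reshape(ph // th, th, pw // tw, tw, c).transpose(0, 2, 1, 3, 4)
    out = fn(tiles.reshape(-1, th, tw, c))  # fn must keep the [B, th, tw, C] shape
    # undo the same reshape/transpose to rebuild the padded image
    rebuilt = out.reshape(ph // th, pw // tw, th, tw, c).transpose(0, 2, 1, 3, 4)
    rebuilt = rebuilt.reshape(ph, pw, c)
    return rebuilt[:h, :w, :]  # slice the padding away
```

Keeping the padding on the bottom and right only means the crop at the end is a plain slice from the origin, which matches the slicing step above.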


Source: https://stackoverflow.com/questions/38235643/getting-started-with-tensorflow-split-image-into-sub-images
