Why are CIFAR-10 images not displayed properly using matplotlib?

前端 未结 9 841
独厮守ぢ
独厮守ぢ 2021-02-06 12:28

From the training set I took an image ('img') of size (3, 32, 32) and displayed it with plt.imshow(img.T), but the image is not clear. What changes do I have to make to the image ('img') to make it m…

相关标签:
9条回答
  • 2021-02-06 13:08

    The following code prints a 5×5 grid of random CIFAR-10 images. It isn't blurry, though it isn't perfect either. Any suggestions are welcome.

    %matplotlib inline
    import numpy as np
    import matplotlib.pyplot as plt
    from six.moves import cPickle 
    
    # Load the first CIFAR-10 training batch. pickle can execute arbitrary
    # code while loading, so only point this at a trusted copy of the dataset.
    with open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb') as f:
        datadict = cPickle.load(f, encoding='latin1')
    X = datadict["data"]
    Y = datadict['labels']
    # Rows are stored channel-first (3 x 32 x 32); matplotlib expects
    # height x width x channels, hence the transpose. -1 infers the number
    # of images instead of hard-coding the batch size.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("uint8")
    Y = np.array(Y)

    # Visualizing CIFAR 10: a 5x5 grid of randomly chosen images.
    fig, axes1 = plt.subplots(5, 5, figsize=(3, 3))
    for j in range(5):
        for k in range(5):
            i = np.random.randint(len(X))
            axes1[j][k].set_axis_off()
            # 'nearest' stops matplotlib from smoothing the tiny 32x32 images
            # when it scales them up, which is what makes them look blurry.
            axes1[j][k].imshow(X[i], interpolation='nearest')
    
    0 讨论(0)
  • 2021-02-06 13:09

    I have used the following code to show all of the CIFAR data as one big image. The code shows the image, but if you want to save it without it being blurry, I suggest using plt.savefig(fname, format='png', dpi=1000).

    import numpy as np
    import matplotlib.pyplot as plt
    
    def reshape_and_print(self, cifar_data):
        """Tile flat CIFAR rows into one big image and display it.

        ``self`` is not used; it is kept so the signature can also serve as
        a method.

        Args:
            cifar_data: 2-D array with one image per row, each row holding the
                three colour planes back to back (3 * height * width values),
                as the reshapes below require.

        Only the largest square number of images (``rows * cols``) is drawn;
        leftover rows are ignored instead of making the reshape fail.
        """
        # number of images in rows and columns of the grid
        rows = cols = int(np.sqrt(cifar_data.shape[0]))
        # Image height and width. Divide by 3 because of 3 color channels
        imh = imw = int(np.sqrt(cifar_data.shape[1] // 3))
        # Keep only the images that fit the square grid.
        grid = cifar_data[:rows * cols]
        # reshape to number of images X color channels X image size
        # transpose to color channels X number of images X image size
        timg = grid.reshape(rows * cols, 3, imh * imw).transpose(1, 0, 2)
        # reshape to color channels X rows X cols X image height X image width
        # swap axes to color channels X rows X image height X cols X image width
        timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
        # reshape to color channels X combined image height X combined image width
        # transpose to combined image height X combined image width X color channels
        timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)

        # 'nearest' avoids the interpolation blur on the small source images.
        plt.imshow(timg, interpolation='nearest')
        plt.show()
    

    I made a quick data helper class that I used for a small test project. I hope it can be useful:

    import gzip
    import pickle
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    class DataSet(object):
        """Base class that loads a dataset and splits it into train/valid/test.

        Subclasses implement the hooks ``load_data`` (returning a
        ``(train_set, test_set)`` pair of arrays indexed by sample along
        axis 0), ``reshape_for_print`` and ``show_all_imgs``.
        """

        def __init__(self, seed=42, setsize=10000):
            self.seed = seed
            # Seed NumPy's global RNG so the permutation in split_data is
            # reproducible (this also affects any other np.random use).
            np.random.seed(seed)
            train_set, test_set = self.load_data()
            self.split_data(train_set, test_set, setsize)

        def split_data(self, data_set, test_set, split_size):
            """Pick disjoint random train/valid subsets and the first test rows.

            ``self.train`` gets ``split_size`` shuffled rows of ``data_set``,
            ``self.valid`` the next ``split_size`` shuffled rows (fewer if the
            data runs out), and ``self.test`` the first ``split_size`` rows of
            ``test_set``.
            """
            permutation = np.random.permutation(data_set.shape[0])
            self.train = data_set[permutation[:split_size]]
            self.valid = data_set[permutation[split_size:split_size * 2]]
            self.test = test_set[:split_size]

        def reshape_for_print(self, data):
            """Arrange ``data`` into a single displayable image (subclass hook)."""
            # NotImplemented is a comparison sentinel, not an exception;
            # raising it would produce a TypeError instead of this message.
            raise NotImplementedError

        def load_data(self):
            """Return ``(train_set, test_set)`` arrays (subclass hook)."""
            raise NotImplementedError

        def show_all_imgs(self, data):
            """Display ``data`` as one tiled image (subclass hook)."""
            raise NotImplementedError
    
    
    class CIFAR(DataSet):
        """Loader for the python-pickle release of CIFAR.

        Each row of the loaded arrays is one flat image with its three colour
        planes stored back to back, converted to float32 and divided by 255.
        """

        # Pickle files for the two splits; override on a subclass or instance
        # to read from another location.
        train_path = './data/cifar-100-python/train'
        test_path = './data/cifar-100-python/test'

        @staticmethod
        def _load_pixels(path):
            """Unpickle ``path`` and return its ``'data'`` array scaled by 1/255."""
            # NOTE(review): pickle.load can execute arbitrary code; only load a
            # trusted copy of the dataset.
            with open(path, 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            return data['data'].astype(np.float32) / 255.0

        def load_data(self):
            """Return ``(train_set, test_set)`` pixel arrays."""
            train_set = self._load_pixels(self.train_path)
            test_set = self._load_pixels(self.test_path)
            return train_set, test_set

        def reshape_for_print(self, data):
            """Tile flat colour rows into one (gh*imh, gw*imw, 3) image.

            Only the largest square number of rows fits the grid; extra rows
            are dropped instead of making the reshape fail.
            """
            gh = gw = int(np.sqrt(data.shape[0]))
            imh = imw = int(np.sqrt(data.shape[1] // 3))
            data = data[:gh * gw]
            # (n, 3, h*w) -> (3, n, h*w): bring the colour planes to the front.
            timg = data.reshape(gh * gw, 3, imh * imw).transpose(1, 0, 2)
            # (3, gh, gw, h, w) -> (3, gh, h, gw, w): interleave grid and pixel axes.
            timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3)
            # Merge to (3, gh*h, gw*w), then move channels last for imshow.
            timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0)
            return timg

        def show_all_imgs(self, data):
            """Show ``data`` as one tiled image without interpolation blur."""
            timg = self.reshape_for_print(data)
            plt.imshow(timg, interpolation='nearest')
            plt.show()
    
    
    class MNIST(DataSet):
        """Loader for the gzipped ``(train, valid, test)`` MNIST pickle."""

        # Gzipped pickle holding the three splits; override to read elsewhere.
        data_path = './data/mnist.pkl.gz'

        def load_data(self):
            """Return the image arrays of the file's train and test splits.

            The file's own validation split is not used; DataSet.split_data
            draws a validation subset from the training images instead.
            """
            # NOTE(review): pickle.load can execute arbitrary code; only load a
            # trusted copy of the dataset.
            with gzip.open(self.data_path, 'rb') as f:
                train_set, _valid_set, test_set = pickle.load(f, encoding='latin1')
            # Element [0] of each split holds the images (presumably [1] holds
            # the labels; verify against the file).
            return train_set[0], test_set[0]

        def reshape_for_print(self, data):
            """Tile flat grayscale rows into one (gh*imh, gw*imw) image.

            Only the largest square number of rows fits the grid; extra rows
            are dropped instead of making the reshape fail.
            """
            gh = gw = int(np.sqrt(data.shape[0]))
            imh = imw = int(np.sqrt(data.shape[1]))
            data = data[:gh * gw]
            # (gh, gw, h, w) -> (gh, h, gw, w) so rows of pixels line up.
            timg = data.reshape(gh, gw, imh, imw).swapaxes(1, 2)
            timg = timg.reshape(gh * imh, gw * imw)
            return timg

        def show_all_imgs(self, data):
            """Show ``data`` as one tiled grayscale image without blur."""
            timg = self.reshape_for_print(data)
            plt.imshow(timg, cmap=plt.cm.gray, interpolation='nearest')
            plt.show()
    
    0 讨论(0)
  • 2021-02-06 13:15

    The image is blurry because of interpolation. To prevent blurring in matplotlib, call imshow with the keyword argument interpolation='nearest':

    # interpolation='nearest' draws each source pixel as a solid block, so the
    # small 32x32 image is not smoothed (blurred) when it is scaled up.
    # Note: .T reverses all axes, (3, 32, 32) -> (32, 32, 3), which swaps x and y.
    plt.imshow(img.T, interpolation='nearest')
    

    Also, it appears that your x and y axes are swapped when you use the transpose, so you may want to display the image like this instead:

    # Move only the channel axis to the end, (channels, rows, cols) ->
    # (rows, cols, channels), instead of reversing all axes like .T does,
    # so the image is not mirrored across its diagonal.
    plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')
    
    0 讨论(0)
提交回复
热议问题