Why are CIFAR-10 images not displayed properly using matplotlib?

前端 未结 9 842
独厮守ぢ
独厮守ぢ 2021-02-06 12:28

From the training set I took an image ('img') of size (3,32,32). I have used plt.imshow(img.T). The image is not clear. What changes do I have to make to the image ('img') to make it m…

9条回答
  •  温柔的废话
    2021-02-06 13:09

    I have used the following code to show all CIFAR data as one big image. The code shows the image, but if you want to save it without it being blurry, I suggest using plt.savefig(fname, format='png', dpi=1000)

    import numpy as np
    import matplotlib.pyplot as plt
    
    def reshape_and_print(self, cifar_data):
        """Tile flattened 3-channel images into one square mosaic and show it.

        Args:
            self: Not used by this function; kept so existing calls still work.
            cifar_data: Array with one flattened image per row. Each row holds
                three equally sized square color channels stored one after
                another (the reshapes below rely on this layout).

        Raises:
            ValueError: If a row cannot be split into 3 square channels.

        Note:
            When the number of images is not a perfect square, only the first
            ``rows * cols`` images are shown instead of failing in ``reshape``.
        """
        # Number of images in rows and columns: the largest square grid that fits.
        rows = cols = int(np.sqrt(cifar_data.shape[0]))
        # Image height and width. Divide by 3 because of the 3 color channels.
        imh = imw = int(np.sqrt(cifar_data.shape[1] // 3))
        if 3 * imh * imw != cifar_data.shape[1]:
            raise ValueError(
                "each row must hold 3 square color channels, got {} values".format(
                    cifar_data.shape[1]))

        # Drop the trailing images that do not fit into the square grid.
        tiles = cifar_data[:rows * cols]
        # Reshape to number of images X color channels X image size,
        # then transpose to color channels X number of images X image size.
        timg = tiles.reshape(rows * cols, 3, imh * imw).transpose(1, 0, 2)
        # Reshape to color channels X rows X cols X image height X image width,
        # then swap axes to color channels X rows X image height X cols X image width.
        timg = timg.reshape(3, rows, cols, imh, imw).swapaxes(2, 3)
        # Reshape to color channels X combined height X combined width,
        # then transpose to combined height X combined width X color channels.
        timg = timg.reshape(3, rows * imh, cols * imw).transpose(1, 2, 0)

        plt.imshow(timg)
        plt.show()
    

    I made a quick data helper class that I used for a small test project; I hope it can be useful:

    import gzip
    import pickle
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    class DataSet(object):
        """Base helper that loads a data set and splits it into train/valid/test.

        Subclasses are expected to override ``load_data``, ``reshape_for_print``
        and ``show_all_imgs``; the versions here only raise
        ``NotImplementedError``.
        """

        def __init__(self, seed=42, setsize=10000):
            """Load the data via ``load_data`` and split it via ``split_data``.

            Args:
                seed: Value passed to ``np.random.seed`` so the shuffle is
                    repeatable. Note that this seeds NumPy's global generator.
                setsize: Requested number of samples per split.
            """
            self.seed = seed
            # Set the seed for reproducibility of the permutation in split_data.
            np.random.seed(seed)
            train_set, test_set = self.load_data()
            self.split_data(train_set, test_set, setsize)

        def split_data(self, data_set, test_set, split_size):
            """Store shuffled train/valid subsets and a test subset on ``self``.

            Args:
                data_set: Array whose first axis indexes samples; train and
                    valid are drawn from it without overlap.
                test_set: Array whose first ``split_size`` entries become the
                    test split (not shuffled).
                split_size: Requested size of each split. If ``data_set`` has
                    fewer than ``2 * split_size`` samples, the slices below
                    simply come out shorter.
            """
            permutation = np.random.permutation(data_set.shape[0])
            self.train = data_set[permutation[:split_size]]
            self.valid = data_set[permutation[split_size:split_size * 2]]
            self.test = test_set[:split_size]

        def reshape_for_print(self, data):
            """Return ``data`` arranged as a single image; must be overridden."""
            # NotImplemented is a comparison sentinel, not an exception; raising
            # it fails with TypeError, so raise NotImplementedError instead.
            raise NotImplementedError

        def load_data(self):
            """Return ``(train_set, test_set)``; must be overridden."""
            raise NotImplementedError

        def show_all_imgs(self, data):
            """Display ``data`` as one image; must be overridden."""
            raise NotImplementedError
    
    
    class CIFAR(DataSet):
        """Data set backed by the pickled ``cifar-100-python`` train/test files."""

        def load_data(self):
            """Read the train and test pickles and rescale their ``'data'`` arrays.

            Returns:
                Tuple ``(train_set, test_set)`` of float32 arrays: the ``'data'``
                entry of each file divided by 255 (presumably mapping 8-bit
                pixel values into [0, 1]).
            """
            # NOTE(security): pickle.load can execute arbitrary code. Only load
            # data set files obtained from a trusted source.
            with open('./data/cifar-100-python/train', 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            train_set = data['data'].astype(np.float32) / 255.0

            with open('./data/cifar-100-python/test', 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            test_set = data['data'].astype(np.float32) / 255.0

            return train_set, test_set

        def reshape_for_print(self, data):
            """Tile flattened 3-channel images into one square mosaic.

            Args:
                data: Array with one flattened image per row. Each row holds
                    three equally sized square channels stored one after another.

            Returns:
                Array of shape ``(gh * imh, gw * imw, 3)`` where ``gh == gw`` is
                the grid size and ``imh == imw`` the size of a single image.
                When the number of images is not a perfect square, only the
                first ``gh * gw`` images are used instead of failing in
                ``reshape``.

            Raises:
                ValueError: If a row cannot be split into 3 square channels.
            """
            # Grid height/width: the largest square grid that fits the images.
            gh = gw = int(np.sqrt(data.shape[0]))
            # Image height/width; divide by 3 because of the 3 color channels.
            imh = imw = int(np.sqrt(data.shape[1] // 3))
            if 3 * imh * imw != data.shape[1]:
                raise ValueError(
                    "each row must hold 3 square color channels, got {} values".format(
                        data.shape[1]))

            # images X channels X pixels -> channels X images X pixels
            timg = data[:gh * gw].reshape(gh * gw, 3, imh * imw).transpose(1, 0, 2)
            # channels X grid row X grid col X imh X imw -> swap grid col and imh
            timg = timg.reshape(3, gh, gw, imh, imw).swapaxes(2, 3)
            # channels X total height X total width -> height X width X channels
            timg = timg.reshape(3, gh * imh, gw * imw).transpose(1, 2, 0)
            return timg

        def show_all_imgs(self, data):
            """Display all images in ``data`` as one mosaic via matplotlib."""
            timg = self.reshape_for_print(data)
            plt.imshow(timg)
            plt.show()
    
    
    class MNIST(DataSet):
        """Data set backed by the gzipped ``mnist.pkl.gz`` pickle."""

        def load_data(self):
            """Read the pickle and return the first element of train and test.

            Returns:
                Tuple ``(train_set[0], test_set[0])``; presumably the image
                arrays without their labels. The validation part stored in the
                file is ignored.
            """
            # NOTE(security): pickle.load can execute arbitrary code. Only load
            # data set files obtained from a trusted source.
            with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
                train_set, _, test_set = pickle.load(f, encoding='latin1')
            return train_set[0], test_set[0]

        def reshape_for_print(self, data):
            """Tile flattened single-channel images into one square mosaic.

            Args:
                data: Array with one flattened square image per row.

            Returns:
                Array of shape ``(gh * imh, gw * imw)`` where ``gh == gw`` is the
                grid size and ``imh == imw`` the size of a single image. When
                the number of images is not a perfect square, only the first
                ``gh * gw`` images are used instead of failing in ``reshape``.

            Raises:
                ValueError: If a row does not hold a square image.
            """
            # Grid height/width: the largest square grid that fits the images.
            gh = gw = int(np.sqrt(data.shape[0]))
            imh = imw = int(np.sqrt(data.shape[1]))
            if imh * imw != data.shape[1]:
                raise ValueError(
                    "each row must hold a square image, got {} values".format(
                        data.shape[1]))

            # grid row X grid col X imh X imw -> grid row X imh X grid col X imw
            timg = data[:gh * gw].reshape(gh, gw, imh, imw).swapaxes(1, 2)
            timg = timg.reshape(gh * imh, gw * imw)
            return timg

        def show_all_imgs(self, data):
            """Display all images in ``data`` as one grayscale mosaic."""
            timg = self.reshape_for_print(data)
            plt.imshow(timg, cmap=plt.cm.gray)
            plt.show()
    

提交回复
热议问题