Why are CIFAR-10 images not displayed properly using matplotlib?

独厮守ぢ 2021-02-06 12:28

From the training set I took an image ('img') of size (3, 32, 32). I used plt.imshow(img.T), but the image is not clear. What changes do I have to make to the image ('img') to make it m…

  • 2021-02-06 12:50

    Try using:

    import matplotlib.pyplot as plt
    from scipy.misc import toimage  # removed in SciPy 1.2; requires an older SciPy plus Pillow
    plt.imshow(toimage(img))
    

    I am not 100% sure of the details, but I think the problem is that the images are stored as floating-point NumPy arrays in channel-first (3, 32, 32) order, which imshow() cannot map to the right colors directly. toimage() finds the channel axis (the one of length 3), moves it last, rescales the values to the 0 to 255 range and returns a PIL Image, which imshow() displays correctly. The result is an in-memory image object, not an encoded .png or .jpg file.

    This code works for me every time I want to display images in python.
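    Since scipy.misc.toimage no longer exists in current SciPy releases, here is a minimal NumPy-only sketch of what it did in this case, assuming img is a (3, 32, 32) array: move the channel axis last and, unless the data is already uint8, rescale it linearly so the minimum maps to 0 and the maximum to 255. The helper name to_uint8_hwc is hypothetical, and the rounding only approximates toimage's byte scaling.

```python
import numpy as np


def to_uint8_hwc(img):
    """Hypothetical helper approximating scipy.misc.toimage for a (3, H, W) array.

    Moves the channel axis last and, unless the data is already uint8,
    linearly rescales the values so the minimum maps to 0 and the maximum to 255.
    """
    hwc = np.transpose(np.asarray(img), (1, 2, 0))  # (C, H, W) -> (H, W, C)
    if hwc.dtype == np.uint8:
        return hwc
    lo, hi = hwc.min(), hwc.max()
    # Avoid dividing by zero when every pixel has the same value
    scale = 255.0 / (hi - lo) if hi > lo else 0.0
    return ((hwc - lo) * scale).round().astype(np.uint8)


# Usage with a random float image in channel-first layout
img = np.random.rand(3, 32, 32).astype(np.float32)
out = to_uint8_hwc(img)
print(out.shape, out.dtype)  # (32, 32, 3) uint8
```

    The result can be passed straight to plt.imshow(), or wrapped with Pillow's Image.fromarray() if you want a PIL image as toimage() returned.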

  • 2021-02-06 12:53

    This script reads the CIFAR-10 dataset and plots individual images using matplotlib.

    import pickle
    import argparse
    import numpy as np
    import os
    import matplotlib.pyplot as plt
    
    cifar10 = "./cifar-10-batches-py/"
    
    parser = argparse.ArgumentParser("Plot training images in cifar10 dataset")
    parser.add_argument("-i", "--image", type=int, default=0, 
                        help="Index of the image in cifar10. In range [0, 49999]")
    args = parser.parse_args()
    
    
    def unpickle(file):
        with open(file, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        return batch
    
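    def cifar10_plot(data, meta, im_idx=0):
        # Each row is 3072 uint8 values: 1024 red, 1024 green, then 1024 blue,
        # each channel stored row by row as a 32x32 plane
        im = data[b'data'][im_idx, :]
    
        im_r = im[0:1024].reshape(32, 32)
        im_g = im[1024:2048].reshape(32, 32)
        im_b = im[2048:].reshape(32, 32)
    
        # Stack the channels along the last axis -> (32, 32, 3), the layout imshow expects
        img = np.dstack((im_r, im_g, im_b))
    
        print("shape: ", img.shape)
        print("label: ", data[b'labels'][im_idx])
        print("category:", meta[b'label_names'][data[b'labels'][im_idx]])
    
        plt.imshow(img)
        plt.show()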
    
    
    def main():
        batch = (args.image // 10000) + 1
        idx = args.image - (batch-1)*10000
    
        data = unpickle(os.path.join(cifar10, "data_batch_" + str(batch)))
        meta = unpickle(os.path.join(cifar10, "batches.meta"))
    
        cifar10_plot(data, meta, im_idx=idx)
    
    
    if __name__ == "__main__":
        main()
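    The three slices plus np.dstack in cifar10_plot can also be written as one reshape and one transpose, because each row stores the red, green and blue planes one after another. A sketch, assuming data is the (10000, 3072) uint8 array from data[b'data']; the helper names row_to_image and batch_to_images are hypothetical:

```python
import numpy as np


def row_to_image(row):
    """Hypothetical helper: one 3072-value CIFAR-10 row -> (32, 32, 3) image."""
    # (3072,) -> (3, 32, 32) channel planes -> (32, 32, 3) for imshow
    return row.reshape(3, 32, 32).transpose(1, 2, 0)


def batch_to_images(data):
    """Hypothetical helper: (N, 3072) rows -> (N, 32, 32, 3) images."""
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)


# Check against the slice-and-dstack version on fake data
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(4, 3072), dtype=np.uint8)
row = data[1]
manual = np.dstack((row[:1024].reshape(32, 32),
                    row[1024:2048].reshape(32, 32),
                    row[2048:].reshape(32, 32)))
print(np.array_equal(row_to_image(row), manual))  # True
print(batch_to_images(data).shape)  # (4, 32, 32, 3)
```

    Similarly, the batch/index arithmetic in main() is equivalent to batch, idx = divmod(args.image, 10000) followed by batch += 1.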
    
  • 2021-02-06 12:56

    Make sure you don't normalize your dataset when you want to display the image.

    Example:

    The loader...

    import torch
    from torchvision import datasets, transforms
    import matplotlib.pyplot as plt
    
    
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                            #  transforms.Normalize(
                            #      (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
                         ])),
        batch_size=64, shuffle=True)
    

    The code that shows the image...

    img = next(iter(train_loader))[0][0]
    plt.imshow(transforms.ToPILImage()(img))
    

    With normalization:

    [Image: normalized sample]

    Without normalization:

    [Image: sample without normalization]
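    If you want to keep transforms.Normalize for training, you can instead undo it only for display: multiply each channel by its std, add back its mean, then move the channels last. A NumPy sketch, assuming img is the (3, 32, 32) tensor converted with img.numpy() and that the mean/std are the values from the commented-out Normalize call; the helper name denormalize is hypothetical:

```python
import numpy as np

# Per-channel statistics taken from the commented-out transforms.Normalize call
MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
STD = np.array([0.247, 0.243, 0.261]).reshape(3, 1, 1)


def denormalize(img):
    """Hypothetical helper: invert x_norm = (x - mean) / std, return (H, W, C) in [0, 1]."""
    restored = img * STD + MEAN               # broadcasts each channel's stats over its 32x32 plane
    restored = np.clip(restored, 0.0, 1.0)    # guard against values slightly outside [0, 1]
    return np.transpose(restored, (1, 2, 0))  # (C, H, W) -> (H, W, C)


# Round trip on a fake image: normalize, then denormalize for display
x = np.random.rand(3, 32, 32)
x_norm = (x - MEAN) / STD
shown = denormalize(x_norm)
print(np.allclose(shown, np.transpose(x, (1, 2, 0))))  # True
```

    With that, plt.imshow(denormalize(img.numpy())) shows the colors as they were before normalization.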

  • 2021-02-06 12:59

    I made a function to plot the RGB image from a row of the CIFAR-10 dataset. The image will be blurry at best, since the original image is very small (32×32 px).

    [Image: sample output]

    import pickle
    import numpy as np
    import matplotlib.pyplot as plt
    
    def unpickle(file):
        with open(file, 'rb') as fo:
            dict1 = pickle.load(fo, encoding='bytes')
        return dict1
    
    # Stack the five training batches into single arrays
    batches = [unpickle('data/data_batch_' + str(i)) for i in range(1, 6)]
    tr_x = np.concatenate([batch[b'data'] for batch in batches])    # (50000, 3072) uint8
    tr_y = np.concatenate([batch[b'labels'] for batch in batches])  # (50000,)
    
    test = unpickle('data/test_batch')
    ts_x = np.asarray(test[b'data'])
    ts_y = np.asarray(test[b'labels'])
    labels = unpickle('data/batches.meta')[b'label_names']
    
    def plot_CIFAR(ind):
        arr = tr_x[ind]
        # Each row holds 1024 red, 1024 green, then 1024 blue values (32x32 each);
        # dividing by 255 maps the uint8 pixels to floats in [0, 1]
        R = arr[0:1024].reshape(32, 32) / 255.0
        G = arr[1024:2048].reshape(32, 32) / 255.0
        B = arr[2048:].reshape(32, 32) / 255.0
    
        img = np.dstack((R, G, B))  # (32, 32, 3)
        # Label names are bytes (e.g. b'cat'); decode them to get a clean title
        title = labels[tr_y[ind]].decode('utf-8')
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111)
        ax.imshow(img, interpolation='bicubic')
        ax.set_title('Category = ' + title, fontsize=15)
        plt.show()
    
    plot_CIFAR(4)
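    Part of the blur comes from interpolation='bicubic', which smooths the 32×32 pixels when they are enlarged on screen. If you prefer crisp, blocky pixels, you can pass interpolation='nearest' to imshow, or enlarge the array yourself by repeating each pixel. A NumPy sketch of the latter, assuming img is a (32, 32, 3) array such as the one built in plot_CIFAR; the helper name upscale_nearest is hypothetical:

```python
import numpy as np


def upscale_nearest(img, factor):
    """Hypothetical helper: enlarge an (H, W, C) image by repeating each pixel."""
    # Repeat rows, then columns, so every source pixel becomes a factor x factor block
    return img.repeat(factor, axis=0).repeat(factor, axis=1)


img = np.random.rand(32, 32, 3)
big = upscale_nearest(img, 8)
print(big.shape)  # (256, 256, 3)
```

    ax.imshow(upscale_nearest(img, 8)) then shows each original pixel as a sharp 8×8 block.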
    
  • 2021-02-06 12:59

    Try the code below. Its result:

    [Image: code result]

    I found a very useful page about visualizing MNIST and CIFAR images, with code for several datasets: https://machinelearningmastery.com/how-to-load-and-visualize-standard-computer-vision-datasets-with-keras/ The CIFAR-10 code is below, and it works well (its output is the image above).

    # example of loading the cifar10 dataset
    from matplotlib import pyplot
    from keras.datasets import cifar10
    # load dataset
    (trainX, trainy), (testX, testy) = cifar10.load_data()
    # summarize loaded dataset
    print('Train: X=%s, y=%s' % (trainX.shape, trainy.shape))
    print('Test: X=%s, y=%s' % (testX.shape, testy.shape))
    # plot first few images
    for i in range(9):
        # define subplot
        pyplot.subplot(330 + 1 + i)
        # plot raw pixel data
        pyplot.imshow(trainX[i])
    # show the figure
    pyplot.show()
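    This works without any reshaping because keras.datasets.cifar10.load_data() returns uint8 images already in (32, 32, 3) order under the default channels_last setting. If you would rather show the first nine as one picture instead of nine subplots, you can tile them into a single array. A NumPy sketch, assuming images has shape (N, 32, 32, 3) like trainX; the helper name make_grid is hypothetical:

```python
import numpy as np


def make_grid(images, rows, cols):
    """Hypothetical helper: tile rows*cols (H, W, C) images into one (rows*H, cols*W, C) array."""
    n, h, w, c = images.shape
    if n < rows * cols:
        raise ValueError("not enough images for the grid")
    tiles = images[: rows * cols].reshape(rows, cols, h, w, c)
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    return tiles.transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)


images = np.random.randint(0, 256, size=(9, 32, 32, 3), dtype=np.uint8)
grid = make_grid(images, 3, 3)
print(grid.shape)  # (96, 96, 3)
```

    pyplot.imshow(make_grid(trainX, 3, 3)) would then display the nine images row by row in a single axes.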
    
  • 2021-02-06 13:07

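    Move the channel axis last and add 0.5 (this assumes your preprocessing shifted the pixel values to [-0.5, 0.5]):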

    plt.imshow(np.transpose(img, (1, 2, 0)) + 0.5)
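    The transpose with (1, 2, 0) also points at one likely problem with img.T in the question: for a 3-D array, .T reverses all axes, so (C, H, W) becomes (W, H, C) and the picture comes out with rows and columns swapped, whereas np.transpose(img, (1, 2, 0)) gives the (H, W, C) layout imshow expects. A small NumPy check, using a toy non-square array (an assumption, so that the difference shows up in the shapes):

```python
import numpy as np

# Toy "image" in channel-first layout: 3 channels, height 2, width 4
img = np.arange(3 * 2 * 4).reshape(3, 2, 4)

wrong = img.T                         # reverses every axis: (C, H, W) -> (W, H, C)
right = np.transpose(img, (1, 2, 0))  # moves only the channel axis: (C, H, W) -> (H, W, C)

print(wrong.shape)  # (4, 2, 3)
print(right.shape)  # (2, 4, 3)
# .T equals the correct layout with height and width swapped, i.e. a transposed picture
print(np.array_equal(wrong, right.transpose(1, 0, 2)))  # True
```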
    