From the training set I took an image ('img') of size (3, 32, 32) and displayed it with plt.imshow(img.T). The image is not clear. What changes do I have to make to the image ('img') to make it more clear?
The following prints a 5x5 grid of random CIFAR-10 images. It isn't blurry, though it isn't perfect either. Any suggestions are welcome.
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from six.moves import cPickle
# Load one CIFAR-10 batch and show a 5x5 grid of randomly chosen images.
# NOTE: unpickling can execute arbitrary code -- only load batch files that
# come from a trusted source.
with open('data/cifar10/cifar-10-batches-py/data_batch_1', 'rb') as f:
    # The context manager closes the file even if unpickling raises.
    datadict = cPickle.load(f, encoding='latin1')

X = datadict["data"]
Y = np.array(datadict['labels'])

# Each row is one flattened image stored channel-first. Reshape to
# (N, 3, 32, 32) and move the channel axis last, because imshow expects
# (height, width, channels). -1 lets NumPy infer the number of images
# instead of hard-coding the batch size.
X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("uint8")

# Visualizing CIFAR 10
fig, axes1 = plt.subplots(5, 5, figsize=(3, 3))
for ax in axes1.flat:
    # Draw a random image index (with replacement across cells).
    i = np.random.choice(len(X))
    ax.set_axis_off()
    ax.imshow(X[i])
I have used the following code to show all CIFAR data as one big image. The code shows the image, but if you want to save it without it being blurry, I suggest using plt.savefig(fname, format='png', dpi=1000)
import numpy as np
import matplotlib.pyplot as plt
def reshape_and_print(self, cifar_data):
    """Show every image in ``cifar_data`` as a single square mosaic.

    ``cifar_data`` is indexed as a 2-D array: ``shape[0]`` is taken as the
    number of images and ``shape[1]`` as the flattened length of one image,
    which is divided by 3 for the three colour channels. Both the image
    count and the per-channel pixel count are square-rooted and truncated to
    integers, so they are expected to be perfect squares; otherwise the
    reshape raises ``ValueError``.

    ``self`` is accepted but not used.
    """
    n_images, flat_len = cifar_data.shape[0], cifar_data.shape[1]
    # The mosaic is grid x grid tiles, each tile being side x side pixels.
    grid = np.sqrt(n_images).astype(np.int32)
    side = np.sqrt(flat_len // 3).astype(np.int32)

    # View the data as (tile row, tile col, channel, pixel row, pixel col).
    tiles = cifar_data.reshape(grid, grid, 3, side, side)
    # Reorder to (tile row, pixel row, tile col, pixel col, channel) so that
    # merging adjacent axes yields (mosaic height, mosaic width, channel).
    mosaic = tiles.transpose(0, 3, 1, 4, 2)
    mosaic = mosaic.reshape(grid * side, grid * side, 3)

    plt.imshow(mosaic)
    plt.show()
I made a quick data helper class that I used for a small test project. I hope it can be useful:
import gzip
import pickle
import numpy as np
import matplotlib.pyplot as plt
class DataSet(object):
    """Base class that loads a data set and splits it into train/valid/test.

    Subclasses must override ``load_data``; ``reshape_for_print`` and
    ``show_all_imgs`` are optional hooks for visualisation. After
    construction the splits are available as ``self.train``, ``self.valid``
    and ``self.test``.
    """

    def __init__(self, seed=42, setsize=10000):
        """Load the data and build the splits.

        Args:
            seed: Value passed to ``np.random.seed`` before shuffling.
            setsize: Number of rows requested for each split.
        """
        self.seed = seed
        # Seed NumPy's global RNG so the shuffle in split_data is
        # reproducible.
        np.random.seed(seed)
        train_set, test_set = self.load_data()
        self.split_data(train_set, test_set, setsize)

    def split_data(self, data_set, test_set, split_size):
        """Populate ``self.train``, ``self.valid`` and ``self.test``.

        ``data_set`` is shuffled along its first axis; the first
        ``split_size`` shuffled rows become the training split and the next
        ``split_size`` rows the validation split. The test split is the
        first ``split_size`` rows of ``test_set`` in their original order.
        Slicing does not raise when fewer rows are available, so a split can
        end up shorter than ``split_size``.
        """
        permutation = np.random.permutation(data_set.shape[0])
        self.train = data_set[permutation[:split_size]]
        self.valid = data_set[permutation[split_size:split_size * 2]]
        self.test = test_set[:split_size]

    def reshape_for_print(self, data):
        """Arrange ``data`` into one displayable image (subclass hook)."""
        # NotImplementedError is the exception; ``NotImplemented`` is a
        # sentinel value and raising it produces a TypeError instead.
        raise NotImplementedError

    def load_data(self):
        """Return ``(train_set, test_set)``; must be overridden."""
        raise NotImplementedError

    def show_all_imgs(self, data):
        """Display all images in ``data`` at once (subclass hook)."""
        raise NotImplementedError
class CIFAR(DataSet):
    """Data set backed by the pickled CIFAR-100 ``train`` and ``test`` files."""

    def load_data(self):
        """Return ``(train_set, test_set)`` as float32 arrays divided by 255.

        Each file is unpickled and only its ``'data'`` entry is kept.
        """
        splits = []
        for path in ('./data/cifar-100-python/train',
                     './data/cifar-100-python/test'):
            with open(path, 'rb') as f:
                data = pickle.load(f, encoding='latin1')
            splits.append(data['data'].astype(np.float32) / 255.0)
        train_set, test_set = splits
        return train_set, test_set

    def reshape_for_print(self, data):
        """Tile the flattened colour images in ``data`` into one mosaic.

        ``data.shape[0]`` is taken as the image count and ``data.shape[1]``
        as the flattened image length (three channels). Both the count and
        the per-channel pixel count are square-rooted and truncated, so they
        are expected to be perfect squares. The result has shape
        ``(grid * side, grid * side, 3)``.
        """
        n_images, flat_len = data.shape[0], data.shape[1]
        grid = np.sqrt(n_images).astype(np.int32)
        side = np.sqrt(flat_len // 3).astype(np.int32)

        # (image, channel, pixel) -> (channel, image, pixel)
        planes = data.reshape(grid * grid, 3, side * side)
        planes = planes.transpose(1, 0, 2)
        # Split images into grid rows/cols and pixels into image rows/cols,
        # then swap so the axes read (channel, grid row, image row,
        # grid col, image col).
        planes = planes.reshape(3, grid, grid, side, side)
        planes = planes.swapaxes(2, 3)
        # Merge into (channel, mosaic height, mosaic width) and move the
        # channel axis last.
        planes = planes.reshape(3, grid * side, grid * side)
        return planes.transpose(1, 2, 0)

    def show_all_imgs(self, data):
        """Display all images in ``data`` as one mosaic."""
        plt.imshow(self.reshape_for_print(data))
        plt.show()
class MNIST(DataSet):
    """Data set backed by the gzipped ``mnist.pkl.gz`` pickle."""

    def load_data(self):
        """Return the first element of the train and test entries.

        The pickle is unpacked into three entries (train, valid, test); the
        validation entry is read but not returned.
        """
        with gzip.open('./data/mnist.pkl.gz', 'rb') as f:
            train_set, valid_set, test_set = pickle.load(f, encoding='latin1')
            train_data = train_set[0]
            test_data = test_set[0]
        return train_data, test_data

    def reshape_for_print(self, data):
        """Tile the flattened single-channel images into one 2-D mosaic.

        ``data.shape[0]`` is taken as the image count and ``data.shape[1]``
        as the flattened image length. Both are square-rooted and truncated,
        so they are expected to be perfect squares. The result has shape
        ``(grid * side, grid * side)``.
        """
        n_images, flat_len = data.shape[0], data.shape[1]
        grid = np.sqrt(n_images).astype(np.int32)
        side = np.sqrt(flat_len).astype(np.int32)

        # (grid row, grid col, image row, image col)
        tiles = data.reshape(grid, grid, side, side)
        # -> (grid row, image row, grid col, image col)
        tiles = tiles.transpose(0, 2, 1, 3)
        return tiles.reshape(grid * side, grid * side)

    def show_all_imgs(self, data):
        """Display all images in ``data`` as one greyscale mosaic."""
        mosaic = self.reshape_for_print(data)
        plt.imshow(mosaic, cmap=plt.cm.gray)
        plt.show()
The image is blurry due to interpolation. To prevent blurring in matplotlib, call imshow with the keyword argument interpolation='nearest':
plt.imshow(img.T, interpolation='nearest')
Also, it appears that your x and y axes are being swapped when you use the transpose, so you may want to display it like this instead:
plt.imshow(np.transpose(img, (1, 2, 0)), interpolation='nearest')