How can I generate and display a grid of images in PyTorch with plt.imshow and torchvision.utils.make_grid?
Question

I am trying to understand how torchvision interacts with matplotlib to produce a grid of images. It's easy to generate images and display them iteratively:

    import torch
    import torchvision
    import matplotlib.pyplot as plt

    w = torch.randn(10, 3, 640, 640)
    for i in range(0, 10):
        z = w[i]
        plt.imshow(z.permute(1, 2, 0))
        plt.show()

However, displaying these images in a grid does not seem to be as straightforward.

    w = torch.randn(10, 3, 640, 640)
    grid = torchvision.utils.make_grid(w, nrow=5)
    plt.imshow(grid