Question
Say I have a batch of images in the form of tensors with dimensions (B x C x W x H), where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want a resulting tensor in the form 1 x C. Is there a straightforward way to do this?
I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1), but I get the error:
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Edit
After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to get the per-channel mean and standard deviation.
Answer 1:
You just need to rearrange the batch tensor in the right way, from [B, C, W, H] to [B, C, W * H], by:
batch = batch.view(batch.size(0), batch.size(1), -1)
Here is a complete usage example on random data:
Code:
import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0)
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)
Output:
tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
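If a single batch is all that is needed, the reshape can be skipped: reducing over every dimension except the channel one gives the per-channel statistics directly. The sketch below is an illustration on random data, not part of the original answer. It also shows the flatten-based route, where the channel axis has to be moved to the front before collapsing the rest; that step is probably what the attempt in the question was missing.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(8, 3, 28, 28)  # [B, C, W, H]

# Reduce over batch, width and height; only the channel dimension remains.
mean = batch.mean(dim=(0, 2, 3))   # shape [C]
std = batch.std(dim=(0, 2, 3))     # shape [C]

# Flatten-based route: bring the channel axis to the front, then collapse
# everything else. permute() returns a non-contiguous tensor, so reshape()
# is used here because view() cannot merge those dimensions.
flat = batch.permute(1, 0, 2, 3).reshape(batch.size(1), -1)  # [C, B * W * H]
flat_mean = flat.mean(1)
flat_std = flat.std(1)

print(mean, std)
```

Both routes reduce over the same set of pixels for each channel, so they should agree up to floating-point rounding.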
Answer 2:
Note that variances add, not standard deviations, so the per-image variances should be averaged and the square root taken only at the end. See a detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters
Here is the modified code:
import torch

# trainloader is assumed to be an existing DataLoader yielding (images, targets)
nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and variance here
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
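One caveat with averaging per-image variances: it measures the average spread within each image and leaves out the spread between image means, so it can differ from the standard deviation taken over all pixels of the dataset. If the dataset-wide value is what is wanted, one option is to accumulate the per-channel sum and sum of squares and use Var[x] = E[x^2] - (E[x])^2. Below is a minimal sketch on the same kind of random data as the first answer. The names n_pixels, channel_sum and channel_sq_sum are my own, and the last line applies the result by plain broadcasting rather than through transforms.Normalize.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Random stand-in data; replace with a real dataset.
torch.manual_seed(0)
data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
loader = DataLoader(TensorDataset(data, labels), batch_size=8)

n_pixels = 0  # number of pixels seen per channel
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)

for batch, _ in loader:
    batch = batch.double()  # accumulate in float64 to limit rounding error
    n_pixels += batch.size(0) * batch.size(2) * batch.size(3)
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / n_pixels
# Population variance per channel: E[x^2] - (E[x])^2
std = torch.sqrt(channel_sq_sum / n_pixels - mean ** 2)

# Apply the statistics by broadcasting over [B, C, W, H].
normalized = (data - mean.float().view(1, -1, 1, 1)) / std.float().view(1, -1, 1, 1)

print(mean)
print(std)
```

The sums are kept in float64 because E[x^2] - (E[x])^2 can lose precision in float32 when the mean is large relative to the spread.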
Source: https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch