Finding mean and standard deviation across image channels PyTorch

Submitted by 对着背影说爱祢 on 2020-05-27 04:41:06

Question


Say I have a batch of images in the form of tensors with dimensions (B x C x W x H) where B is the batch size, C is the number of channels in the image, and W and H are the width and height of the image respectively. I'm looking to use the transforms.Normalize() function to normalize my images with respect to the mean and standard deviation of the dataset across the C image channels, meaning that I want a resulting tensor in the form 1 x C. Is there a straightforward way to do this?

I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1) but I get the error:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Edit

After looking into how view() works in PyTorch, I now realize why my approach doesn't work; however, I still can't figure out how to get the per-channel mean and standard deviation.


Answer 1:


You just need to reshape the batch tensor from [B, C, W, H] to [B, C, W * H], so that each channel's pixels lie along the last dimension:

batch = batch.view(batch.size(0), batch.size(1), -1)

Here is a complete usage example on random data:

Code:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate the per-image, per-channel mean and std
    mean += batch.mean(2).sum(0)
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

Output:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
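When a whole batch (or the whole dataset) fits in memory, the per-channel statistics can also be computed in one call by reducing over every dimension except the channel dimension. The sketch below is not part of the original answer; the tensor shape and variable names are illustrative assumptions. It also shows why view(C, -1) does not work on a [B, C, W, H] tensor: the channel dimension has to be moved to the front first, and the permuted tensor is no longer contiguous, so reshape() is used instead of view().

```python
import torch

# Illustrative data in [B, C, W, H] layout (shape chosen for the example)
batch = torch.randn(8, 3, 28, 28)

# Reduce over batch, width and height, keeping only the channel dimension
mean = batch.mean(dim=(0, 2, 3))
std = batch.std(dim=(0, 2, 3))

# Equivalent: move channels to the front, then flatten to [C, B * W * H].
# permute() returns a non-contiguous tensor, so reshape() is used here;
# calling view() on it is what triggers the "size and stride" error.
flat = batch.permute(1, 0, 2, 3).reshape(batch.size(1), -1)
mean_flat = flat.mean(1)
std_flat = flat.std(1)

print(mean.shape, std.shape)
```

Both variants give one value per channel, which is the shape transforms.Normalize() expects for its mean and std arguments.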



Answer 2:


Note that variances add, not standard deviations. See detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters

Here is the modified code:

nimages = 0
mean = 0.0
var = 0.0
# trainloader is assumed to be a DataLoader yielding (images, targets) batches
for batch_target in trainloader:
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate the per-image, per-channel mean and variance
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
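Averaging per-image variances measures the spread within each image but leaves out the variation between the image means (the law of total variance), so it can underestimate the dataset-wide standard deviation when images differ a lot in brightness. A possible alternative, sketched here and not taken from either answer, is to accumulate the per-channel sum and sum of squares over all pixels and use Var[x] = E[x^2] - (E[x])^2. The random data, loader settings, and variable names are assumptions for illustration; float64 accumulators are used to limit rounding error.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Illustrative random dataset, mirroring the example in the first answer
data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
loader = DataLoader(TensorDataset(data, labels), batch_size=8)

npixels = 0  # number of pixels seen per channel
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)

for batch, _ in loader:
    batch = batch.double()
    # Each channel contributes B * W * H pixels per batch
    npixels += batch.size(0) * batch.size(2) * batch.size(3)
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / npixels
# Population variance over all pixels of each channel
var = channel_sq_sum / npixels - mean ** 2
std = torch.sqrt(var)

print(mean)
print(std)
```

This gives the population (biased) standard deviation over every pixel in the dataset; with this many pixels the difference from the unbiased estimate is negligible.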


Source: https://stackoverflow.com/questions/60101240/finding-mean-and-standard-deviation-across-image-channels-pytorch
