How to get the filename of a sample from a DataLoader?

Submitted by 你离开我真会死。 on 2019-12-12 15:06:05

Question


I need to write the test results of a convolutional neural network I trained to a file. The data is a collection of speech recordings. The file format needs to be "file name, prediction", but I am having a hard time extracting the file name. I load the data like this:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

and I am trying to write to the file as follows:

import os
import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        # os.listdir returns names in arbitrary order, unrelated to test_loader's order
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        line = file + ", " + str(predicted.item()) + "\n"
        f.write(line)

The problem with os.listdir(TEST_DATA_PATH + "/all")[i] is that its order is not synchronized with the order in which test_loader loads the files. What can I do?


Answer 1:


Well, it depends on how your Dataset is implemented. For instance, in the torchvision.datasets.MNIST(...) case, you cannot retrieve the filename simply because there is no such thing as the filename of a single sample: MNIST is distributed as a few binary files, each holding all the images (or all the labels) of a split, which torchvision loads into tensors.
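For a dataset like MNIST that has no file names, a minimal sketch of the closest substitute is to use each sample's dataset index as its identifier. The IndexedDataset wrapper below is a hypothetical helper, not a PyTorch API; it assumes the wrapped dataset yields (input, target) pairs, as MNIST does. A TensorDataset stands in for MNIST so the example runs without a download.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class IndexedDataset(Dataset):
    """Wraps a map-style dataset and also returns each sample's index."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        # Assumes the base dataset yields (input, target) pairs, like MNIST.
        image, label = self.base[index]
        return image, label, index


# Stand-in for torchvision.datasets.MNIST(...), so the sketch runs offline.
base = TensorDataset(torch.randn(10, 1, 28, 28), torch.arange(10))
loader = DataLoader(IndexedDataset(base), batch_size=4, shuffle=True)
for images, labels, indices in loader:
    print(indices.tolist())  # dataset index of every sample in this batch
```

Because the index travels with each sample, it stays correct with shuffle=True and any batch_size.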

Since you did not show your Dataset implementation, here is how this could be done with torchvision.datasets.ImageFolder(...) (or any torchvision.datasets.DatasetFolder(...)), whose samples attribute is a list of (path, class_index) pairs:

import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        # Valid only with batch_size=1 and shuffle=False: batch i is then samples[i]
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))

This works because DatasetFolder.__getitem__(self, index) retrieves the file path from that same list: it begins with path, target = self.samples[index].
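If you want to keep ImageFolder but also use shuffle=True or batch_size > 1, one option is to wrap it so that each sample carries its own path. This is a sketch, not a torchvision API: WithPath is a hypothetical name, and it relies only on the samples attribute and on the wrapped dataset returning (sample, target) from __getitem__.

```python
from torch.utils.data import DataLoader, Dataset


class WithPath(Dataset):
    """Wraps an ImageFolder/DatasetFolder so each sample also returns its path.

    Relies on the wrapped dataset's `samples` list of (path, class_index) pairs.
    """

    def __init__(self, folder_dataset):
        self.base = folder_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        sample, target = self.base[index]   # loads and transforms the file
        path, _ = self.base.samples[index]  # the entry __getitem__ just read
        return sample, target, path


# Hypothetical usage with the question's TEST_DATA_PATH and trans:
# test_dataset = WithPath(torchvision.datasets.ImageFolder(TEST_DATA_PATH, transform=trans))
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
# for images, labels, paths in test_loader:
#     ...  # paths is a sequence of batch_size file paths
```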

If you implemented your own Dataset (and perhaps would like to support shuffle and batch_size > 1), then I would return sample_fname from __getitem__(...) and do something like this:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

This way you wouldn't need to care about shuffle. And if batch_size is greater than 1, you would need to change the body of the loop to something more generic, e.g.:

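import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
        outputs = model(images)
        pred = torch.max(outputs, 1)[1]  # predicted class index per sample
        # File name first, then prediction, to match the "file name, prediction" format
        f.write("\n".join([
            ", ".join(x)
            for x in zip(samples_fname, map(str, pred.cpu().tolist()))
        ]) + "\n")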
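Putting the pieces together, a minimal sketch of a custom Dataset for an unlabeled test folder might look like the following. FileDataset, write_predictions and the load_fn hook are hypothetical names; torch.load and tensors saved as .pt files are stand-ins for whatever loader your speech data actually needs.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class FileDataset(Dataset):
    """Map-style dataset that returns (tensor, file name) for each file."""

    def __init__(self, root, load_fn=torch.load):
        self.root = root
        # Sort once so the order is deterministic (os.listdir's order is arbitrary).
        self.fnames = sorted(os.listdir(root))
        self.load_fn = load_fn  # hypothetical hook: swap in your own file loader

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):
        fname = self.fnames[index]
        data = self.load_fn(os.path.join(self.root, fname))
        return data, fname


def write_predictions(model, loader, out_path):
    """Write one "file name, prediction" line per sample, in loader order."""
    model.eval()
    with open(out_path, "w") as f, torch.no_grad():
        for inputs, fnames in loader:
            preds = model(inputs).argmax(dim=1)
            for fname, pred in zip(fnames, preds.tolist()):
                f.write(f"{fname}, {pred}\n")
```

With the question's names, usage would be roughly write_predictions(model, DataLoader(FileDataset(TEST_DATA_PATH + "/all"), batch_size=32, shuffle=True), "test_y"), assuming that folder holds only loadable files. Because every sample returns its own name, shuffling and batching no longer affect the pairing.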



Answer 2:


In general, a DataLoader is there to provide you with batches from the Dataset(s) it wraps.

As @Barriel mentioned, in single-/multi-label classification problems the DataLoader doesn't have the image file names, just the tensors representing the images and the classes/labels.

However, the objects the DataLoader loads can carry small extras: together with the data, your Dataset may pack the targets/labels and the file names if you like.

This way, each batch the DataLoader yields will contain what you need.
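One concrete way to read this is through the DataLoader's collate_fn argument, which receives the list of samples for a batch and decides how they are combined. The sketch below is an illustration, not part of the original answer: collate_with_names is a hypothetical helper, and a plain Python list of (tensor, label, file name) tuples stands in for a real Dataset.

```python
import torch
from torch.utils.data import DataLoader


def collate_with_names(batch):
    """Collate (tensor, label, file name) samples into one batch.

    Tensors and labels are stacked; file names stay a plain list of strings.
    """
    tensors, labels, fnames = zip(*batch)
    return torch.stack(tensors), torch.tensor(labels), list(fnames)


# Hypothetical samples standing in for a Dataset's __getitem__ results.
samples = [(torch.randn(3), k % 2, f"sample_{k}.wav") for k in range(5)]
loader = DataLoader(samples, batch_size=2, collate_fn=collate_with_names)
for inputs, labels, fnames in loader:
    print(inputs.shape, labels.tolist(), fnames)
```

The default collate function already keeps strings as a sequence, so a custom collate_fn mainly pays off when you need more control, for example padding variable-length speech clips before stacking them.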



Source: https://stackoverflow.com/questions/56699048/how-to-get-the-filename-of-a-sample-from-a-dataloader
