Iterating over a torchtext.data.BucketIterator object throws AttributeError: 'Field' object has no attribute 'vocab'

Submitted by 拜拜、爱过 on 2020-02-03 16:28:51

Question


When I try to inspect a batch by printing the next item from the BucketIterator object, an AttributeError ('Field' object has no attribute 'vocab') is thrown.

from torchtext import data
from torchtext.data import BucketIterator

# TEXT and LABEL are torchtext Field objects defined earlier (not shown in the question)
tv_datafields = [("Tweet", TEXT), ("Anger", LABEL), ("Fear", LABEL), ("Joy", LABEL), ("Sadness", LABEL)]
train, vld = data.TabularDataset.splits(path="./data/", train="train.csv", validation="test.csv", format="csv", fields=tv_datafields)

train_iter, val_iter = BucketIterator.splits(
    (train, vld),
    batch_sizes=(64, 64),
    device=-1,
    sort_key=lambda x: len(x.Tweet),
    sort_within_batch=False,
    repeat=False
)
print(next(iter(train_iter)))
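This error is characteristic of the legacy torchtext API used above (torchtext.data.Field): a Field only gains its vocab attribute when build_vocab() is called, so iterating a BucketIterator before TEXT.build_vocab(train) has run fails when the batch is numericalized. The sketch below does not use torchtext at all; it is a minimal pure-Python stand-in, where MiniField and MiniVocab are hypothetical names invented for illustration, that reproduces the same failure mode and shows how building the vocabulary first avoids it.

```python
# Hypothetical stand-ins for a legacy torchtext Field/Vocab; NOT the real torchtext classes.
from collections import Counter


class MiniVocab:
    """Maps tokens to integer ids; id 0 is reserved for unknown tokens."""

    def __init__(self, counter):
        self.itos = ["<unk>"] + sorted(counter)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}


class MiniField:
    """Like a legacy Field, it only gains a `vocab` attribute once build_vocab() runs."""

    def build_vocab(self, examples):
        # Count every token in the tokenized training examples.
        counter = Counter(tok for ex in examples for tok in ex)
        self.vocab = MiniVocab(counter)

    def numericalize(self, example):
        # Raises AttributeError if build_vocab() was never called,
        # mirroring "'Field' object has no attribute 'vocab'".
        return [self.vocab.stoi.get(tok, 0) for tok in example]


train = [["i", "am", "happy"], ["so", "sad"]]
field = MiniField()

try:
    field.numericalize(train[0])
except AttributeError as err:
    print("before build_vocab:", err)

field.build_vocab(train)
print("after build_vocab:", field.numericalize(["i", "am", "sad"]))  # [3, 1, 4]
```

With the real legacy API, the analogous fix would be to call TEXT.build_vocab(train) (and LABEL.build_vocab(train), unless LABEL was created with use_vocab=False) before print(next(iter(train_iter))).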

Answer 1:


I am not sure about the specific error you are getting, but in this case you can iterate over the batches with the following code:

# each i is a Batch whose attributes are named after the fields in tv_datafields
for i in train_iter:
    print(i.Tweet)
    print(i.Anger)
    print(i.Fear)
    print(i.Joy)
    print(i.Sadness)

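i.Tweet is a tensor of shape (input_data_length, batch_size). Other sequential fields share this layout, while non-sequential label fields are typically 1-D tensors of shape (batch_size,).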

So, to view a single example within the batch (say, example 0), you can use print(i.Tweet[:, 0]).
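To make the (input_data_length, batch_size) layout concrete, here is a minimal pure-Python sketch. It assumes plain nested lists as a stand-in for the tensor, and the pad id 1 and the helper names to_time_major and column are hypothetical choices for illustration, not torchtext functions. Each row holds one time step for every example, so taking column j (the equivalent of tensor[:, j]) recovers example j with its padding.

```python
# Hypothetical helpers illustrating a time-major (seq_len, batch_size) batch.
# Plain Python lists stand in for a torch tensor; PAD_ID = 1 is an assumption.
PAD_ID = 1


def to_time_major(batch_ids, pad_id=PAD_ID):
    """Pad a list of id sequences and transpose to shape (max_len, batch_size)."""
    max_len = max(len(seq) for seq in batch_ids)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in batch_ids]
    # Row t holds the t-th token of every example in the batch.
    return [list(step) for step in zip(*padded)]


def column(matrix, j):
    """Equivalent of tensor[:, j]: every time step of example j."""
    return [row[j] for row in matrix]


batch = [[5, 8, 2], [7, 3], [4, 9, 6, 2]]
tweet = to_time_major(batch)
print(len(tweet), len(tweet[0]))  # 4 3  (4 time steps, 3 examples)
print(column(tweet, 0))  # [5, 8, 2, 1]
```

If the Field were created with batch_first=True, the layout would instead be (batch_size, input_data_length), and a single example would be selected with i.Tweet[0] rather than i.Tweet[:, 0].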

Same goes for val_iter (and test_iter, if needed).



Source: https://stackoverflow.com/questions/51231852/iterating-over-torchtext-data-bucketiterator-object-throws-attributeerror-field
