How to encode string in tf.data.Dataset?

佐手, submitted on 2020-12-15 01:55:36

Question


So I am trying to encode strings in a tf.data.Dataset in order to use them to train a pretrained RoBERTa model. The training_dataset is a TensorFlow dataset built from a pandas DataFrame that looks like this:
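(The DataFrame itself is not reproduced in the post; a hypothetical stand-in with the columns used below — OptionA, OptionB, OptionC, Answer — could look like this:)

import pandas as pd

# Hypothetical stand-in for train_split; the real data is not shown in the question
train_split = pd.DataFrame({
    'OptionA': ['He went to the store.', 'She opened the window.'],
    'OptionB': ['He flew to the moon.',  'She closed the book.'],
    'OptionC': ['He painted the fence.', 'She started the car.'],
    'Answer':  [0, 2],
})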

I used this dataframe to construct the tf.data.Dataset using:

features = ['OptionA', 'OptionB', 'OptionC']

training_dataset = (
    tf.data.Dataset.from_tensor_slices(
        (
            tf.cast(train_split[features].values, tf.string),
            tf.cast(train_split['Answer'].values, tf.int32)
        )
    )
)

Now I want to encode the three columns OptionA, OptionB and OptionC using a RobertaTokenizer, which is defined by:

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
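For reference, a quick check of what encode returns for a single string (using the tokenizer defined above; the exact IDs depend on the roberta-base vocabulary):

# encode() maps a Python string to a list of token IDs, wrapped in the
# special tokens <s> (id 0) and </s> (id 2) for roberta-base
ids = tokenizer.encode("Hello world")
print(ids)                    # e.g. [0, 31414, 232, 2]
print(tokenizer.decode(ids))  # "<s>Hello world</s>"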

I tried:

training_dataset = training_dataset.map(lambda x: tokenizer.encode(x))

But this gave me the error "TypeError: <lambda>() takes 1 positional argument but 2 were given", and I am not sure how to deal with it, or how to specify that only the first three columns should be encoded.

Any help would be appreciated!


Answer 1:


training_dataset yields (features, label) pairs, but your map function only takes one argument. Try:

training_dataset = training_dataset.map(lambda x, y: (tokenizer.encode(x), y))
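For context, a rough sketch of how this two-argument mapping might be fleshed out end to end. This is not part of the original answer: since tokenizer.encode is plain Python and expects strings rather than symbolic tensors, the sketch assumes it runs eagerly through tf.py_function, and the max length of 64 and the per-option padding are arbitrary choices made here for illustration.

import tensorflow as tf

# Because the dataset was built from a (features, labels) tuple, each element is a pair,
# so the map function receives two arguments:
print(training_dataset.element_spec)
# (TensorSpec(shape=(3,), dtype=tf.string, ...), TensorSpec(shape=(), dtype=tf.int32, ...))

def encode_options(options, label):
    def _encode(opts):
        # opts is an eager tensor of 3 byte strings; encode each option separately
        ids = [tokenizer.encode(o.decode('utf-8')) for o in opts.numpy()]
        # truncate/pad to a fixed length (64 here, chosen arbitrarily) so the result is rectangular
        ids = [i[:64] + [tokenizer.pad_token_id] * (64 - len(i[:64])) for i in ids]
        return tf.constant(ids, dtype=tf.int32)

    encoded = tf.py_function(_encode, inp=[options], Tout=tf.int32)
    encoded.set_shape((3, 64))
    return encoded, label

training_dataset = training_dataset.map(encode_options)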


Source: https://stackoverflow.com/questions/65274777/how-to-encode-string-in-tf-data-dataset
