How to create feature columns for TensorFlow classifier

余生长醉 submitted on 2019-12-05 02:20:18

Question


I have a very simple binary classification dataset in a CSV file that looks like this:

"feature1","feature2","label"
1,0,1
0,1,0
...

where the "label" column indicates the class (1 is positive, 0 is negative). The actual number of features is quite large, but that doesn't matter for this question.

Here is how I read the data:

import pandas

train = pandas.read_csv(TRAINING_FILE)
y_train, X_train = train['label'], train[['feature1', 'feature2']].fillna(0)

test = pandas.read_csv(TEST_FILE)
y_test, X_test = test['label'], test[['feature1', 'feature2']].fillna(0)
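Because the question relies on every feature being 0 or 1, it may be worth checking that before choosing feature columns. The sketch below uses only the Python standard library (no pandas) on the CSV layout shown above. The function name load_binary_csv is hypothetical, and blank cells are filled with 0 to mirror fillna(0).

```python
import csv
import io


def load_binary_csv(text, label="label"):
    """Split CSV text into feature names, feature rows and labels (hypothetical helper)."""
    reader = csv.DictReader(io.StringIO(text))
    feature_names = [name for name in reader.fieldnames if name != label]
    X, y = [], []
    for row in reader:
        # Blank cells become 0, like pandas' fillna(0).
        values = [int(row[name]) if row[name] != "" else 0 for name in feature_names]
        # Reject anything that is not a binary (0/1) feature value.
        if any(v not in (0, 1) for v in values):
            raise ValueError(f"non-binary feature value in row {row}")
        X.append(values)
        y.append(int(row[label]))
    return feature_names, X, y


sample = '"feature1","feature2","label"\n1,0,1\n0,1,0\n,1,1\n'
names, X, y = load_binary_csv(sample)
print(names, X, y)  # ['feature1', 'feature2'] [[1, 0], [0, 1], [0, 1]] [1, 0, 1]
```

The quoted header names are unquoted by the csv module, so the feature names come out as plain strings that can later be used as feature column keys.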

I want to run tensorflow.contrib.learn.LinearClassifier and tensorflow.contrib.learn.DNNClassifier on this data. For instance, I initialize the DNN like this:

from tensorflow import nn
from tensorflow.contrib.learn import DNNClassifier

classifier = DNNClassifier(hidden_units=[3, 5, 3],
                           n_classes=2,
                           feature_columns=feature_columns,  # ???
                           activation_fn=nn.relu,
                           enable_centered_bias=False,
                           model_dir=MODEL_DIR_DNN)

So how exactly should I create feature_columns when all the features are binary (0 and 1 are the only possible values)?

Here is the model training:

classifier.fit(X_train.values,
               y_train.values,
               batch_size=dnn_batch_size,
               steps=dnn_steps)

A solution that replaces the fit() arguments with an input function would also be great.
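For reference, an input function for the contrib.learn estimators returns a pair (features, labels), where features is either a Tensor or a dict mapping feature names to Tensors. The plain-Python sketch below only illustrates that dict-plus-labels shape and the batching that batch_size performs. It yields lists rather than tensors, so it is not itself a TensorFlow input_fn, and the batches helper is hypothetical.

```python
def batches(feature_names, X, y, batch_size):
    """Yield (features, labels) pairs shaped like an input_fn's (dict, labels) result."""
    for start in range(0, len(X), batch_size):
        rows = X[start:start + batch_size]
        # One dict entry per feature name, holding that column's values for this batch.
        features = {name: [row[i] for row in rows] for i, name in enumerate(feature_names)}
        yield features, y[start:start + batch_size]


X = [[1, 0], [0, 1], [1, 1]]
y = [1, 0, 1]
for features, labels in batches(["feature1", "feature2"], X, y, batch_size=2):
    print(features, labels)
# {'feature1': [1, 0], 'feature2': [0, 1]} [1, 0]
# {'feature1': [1], 'feature2': [1]} [1]
```

Keying the dict by column name is what lets named feature columns find their data, whereas passing X_train.values to fit() hands the model one unnamed matrix.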

Thanks!

P.S. I'm using TensorFlow version 1.0.1


Answer 1:


You can directly use tf.feature_column.numeric_column:

feature_columns = [tf.feature_column.numeric_column(key=key) for key in X_train.columns]
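Each numeric_column uses its key as the dictionary key for looking up the feature it reads, so the keys (here the DataFrame column names) must match the keys of the features dict the model receives. The stdlib sketch below is a simplified stand-in for that lookup, not TensorFlow code; dense_input is a hypothetical name, and it assumes each numeric column contributes one scalar per example.

```python
def dense_input(features, keys):
    """Look up each key in the features dict and stack the values as floats.

    A simplified stand-in for one numeric column per key; a KeyError here
    means a column key does not match any feature name.
    """
    n = len(features[keys[0]])
    return [[float(features[key][i]) for key in keys] for i in range(n)]


features = {"feature1": [1, 0], "feature2": [0, 1]}
keys = ["feature1", "feature2"]  # e.g. list(X_train.columns)
print(dense_input(features, keys))  # [[1.0, 0.0], [0.0, 1.0]]
```

Binary features need no special treatment in this scheme: a 0 or 1 is simply read as the float 0.0 or 1.0.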



Answer 2:


I've just found the solution and it's pretty simple:

feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(X_train)

infer_real_valued_columns_from_input() interprets every input as a dense, fixed-length numeric value, so binary (0/1) features work with it without any special handling.
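A rough stdlib illustration of that inference, assuming the contrib.learn behavior in which a plain 2-D matrix (rather than a dict) becomes a single real-valued column with an empty key whose dimension is the number of features. infer_dense_spec is hypothetical and only mimics the shape logic; it is not a TensorFlow function.

```python
def infer_dense_spec(rows):
    """Mimic inferring one dense real-valued column from a 2-D matrix (hypothetical)."""
    width = len(rows[0])
    # A fixed-length column requires every row to have the same number of features.
    if any(len(row) != width for row in rows):
        raise ValueError("rows must all have the same length")
    # Values are used as plain numbers (cast to float here); nothing is one-hot encoded.
    values = [[float(v) for v in row] for row in rows]
    return {"key": "", "dimension": width}, values


spec, values = infer_dense_spec([[1, 0], [0, 1]])
print(spec)  # {'key': '', 'dimension': 2}
```

This is why the inferred columns pair with fit(x=X_train.values, ...): the whole matrix is one unnamed feature, so no per-column names are needed.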



Source: https://stackoverflow.com/questions/42965371/how-to-create-feature-columns-for-tensorflow-classifier
