Question
I am trying to train an estimator with a generator, but I want to feed it a batch of samples on each iteration. Here is the code:
import numpy as np
import tensorflow as tf

def _generator():
    for i in range(100):
        feats = np.random.rand(4,2)
        labels = np.random.rand(4,1)
        yield feats, labels

def input_func_gen():
    shapes = ((4,2),(4,1))
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=shapes)
    dataset = dataset.batch(4)
    # dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels

x_col = tf.feature_column.numeric_column(key='x', shape=(4,2))
# tf_data is the model directory path; it is not defined in this snippet
es = tf.estimator.LinearRegressor(feature_columns=[x_col], model_dir=tf_data)
es = es.train(input_fn=input_func_gen, steps=None)
When I run this code, it raises this error:
raise ValueError(err.message)
ValueError: Dimensions must be equal, but are 2 and 3 for 'linear/head/labels/assert_equal/Equal' (op: 'Equal') with input shapes: [2], [3].
How should I structure this input function?
Thanks!
Answer 1:
TensorFlow adds the batch dimension for you: dataset.batch(4) stacks consecutive elements along a new leading axis, so the batch size must not be part of output_shapes. Your generator should therefore yield single samples.
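To see where the "[2], [3]" in the error most likely comes from, here is a minimal NumPy sketch. It assumes that, for fixed-shape elements, Dataset.batch(n) behaves like stacking n consecutive elements along a new leading axis; the helper batch_shape is hypothetical and only mimics that. With the question's generator the batched labels come out rank 3, while LinearRegressor (default label_dimension of 1) expects rank-2 labels of shape (batch, 1), so the two shape vectors being compared have lengths 2 and 3.

```python
import numpy as np

def batch_shape(sample_shape, batch_size=4):
    """Hypothetical helper: shape of one batch after stacking batch_size samples."""
    samples = [np.random.rand(*sample_shape) for _ in range(batch_size)]
    return np.stack(samples).shape

# Question's generator: every element is already a block of 4 rows.
q_feats = batch_shape((4, 2))   # (4, 4, 2)
q_labels = batch_shape((4, 1))  # (4, 4, 1), i.e. rank 3

# Answer's generator: every element is a single sample.
a_feats = batch_shape((2,))     # (4, 2)
a_labels = batch_shape((1,))    # (4, 1), i.e. rank 2

# The regressor's logits have shape (batch, 1), i.e. rank 2.
logits_rank = 2
print(len(q_labels), logits_rank)  # 3 2  -> mismatch
print(len(a_labels), logits_rank)  # 2 2  -> match
```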
Assuming the 4 in position 0 of your shapes is meant to be the batch size, this works:
import tensorflow as tf
import numpy

def _generator():
    for i in range(100):
        # Yield ONE sample at a time; batching is done by the dataset
        feats = numpy.random.rand(2)
        labels = numpy.random.rand(1)
        yield feats, labels

def input_func_gen():
    # Per-sample shapes, without the batch dimension.
    # (2,) is a one-element tuple; a bare (2) is just the int 2, which TF also accepts here.
    shapes = ((2,), (1,))
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=shapes)
    dataset = dataset.batch(4)  # adds the leading batch axis: (4, 2) and (4, 1)
    # dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels

x_col = tf.feature_column.numeric_column(key='x', shape=(2,))
es = tf.estimator.LinearRegressor(feature_columns=[x_col])
es = es.train(input_fn=input_func_gen, steps=None)
Source: https://stackoverflow.com/questions/49673602/train-tensorflow-model-with-estimator-from-generator