Could two tf.data.Dataset coexist and controled by tf.cond()

不问归期 提交于 2019-12-13 06:18:25

问题


I put two Dataset pipeline for train/test = 9:1 set in my Graph and the control the flow by a tf.cond. I encountered a problem that during the training the both pipelines are activated at each step. The testset ran out before the trainset as it has less during training.

OutOfRangeError (see above for traceback): End of sequence

First, nest the input pipeline in a function:

def input_pipeline(*args):
    ...
    # construct iterator
    it = batch.make_initializable_iterator()
    iter_init_op = it.initializer

    # get next img and label
    X_it, y_it = it.get_next()
    inputs = {'img': X_it, 'label': y_it, 'iterator_init_op': iter_init_op}
    return inputs

Then initiate:

train_input = input_pipeline(args)
test_input = input_pipeline(args)

In the model, we put a placeholder to fill condition with feed_dict which won't penalize the performance:

...
def f1(): return train_input
def f2(): return test_input
cond_pl = tf.placeholder(tf.string, name='cond_pl')
input = tf.cond(tf.equal(cond_pl, 'train'), lambda: f1(), lambda: f2())
...

Within the session:

for ep in range(nb_ep):
   ...
   for step in range(ep_len):
       print('step:{}\r'.format(step))
       try:
           sess.run([train_op], feed_dict={cond_pl: 'train'})
           if step % step_len == (step_len - 1):
               sess.run([test_op], feed_dict={cond_pl: 'test'})
       except tf.errors.OutOfRangeError:
           raise('drop the remainder')
   ...

How can I let the input pipeline get_next() get called only if the condition fits?


Snippet to play around updated based on @sharky answer:

def write_h5(args):
    x, is_training = args
    with h5py.File('./{}_{}.h5'.format('train' if is_training else 'test', x), 'w') as f:
        h = w = np.arange(-1, 1, 0.02)
        hh, _ = np.meshgrid(h, w)
        a = hh ** 2
        b = np.add(a + 1, np.random.randn(100, 100))  #do something and add gaussian noise
        f.create_dataset('X', shape=(100, 100), dtype='float32', data=a)
        f.create_dataset('y', shape=(100, 100), dtype='float32', data=b)


def input_pipeline(window_size, batch_size, is_train=True, ncores=mp.cpu_count()):
    flist = []
    for dirpath, _, fnames in os.walk('./'):
        for fname in fnames:
            if fname.startswith('train' if is_train else 'test') and fname.endswith('.h5'):
                print(fname)
                flist.append((os.path.abspath(os.path.join(dirpath, fname)), str(window_size)))
    f_len = len(flist)
    print(f_len)
    # init list of files
    batch = tf.data.Dataset.from_tensor_slices((tf.constant(flist)))
    batch = batch.map(_pyfn_wrapper, num_parallel_calls=ncores)
    batch = batch.shuffle(batch_size).batch(batch_size, drop_remainder=True).prefetch(ncores).repeat()

    # construct iterator
    it = batch.make_initializable_iterator()
    iter_init_op = it.initializer

    # get next img and label
    X_it, y_it = it.get_next()
    inputs = {'img': X_it, 'label': y_it, 'iterator_init_op': iter_init_op}
    return inputs, f_len


def _pyfn_wrapper(args):
    return tf.py_func(parse_h5,  #wrapped pythonic function
                      [args],
                      [tf.float32, tf.float32]  #[input, output] dtype
                      )

def parse_h5(args):
    name, window_size = args
    window_size = int(window_size.decode('utf-8'))
    with h5py.File(name, 'r') as f:
        X = f['X'][:].reshape(window_size, window_size, 1)
        y = f['y'][:].reshape(window_size, window_size, 1)
        return X, y


# init data
p = mp.Pool(mp.cpu_count())
p.map(write_h5, zip(range(9000), repeat(True)))
p.map(write_h5, zip(range(1000), repeat(False)))

# hparam
ep_len = 90
step_len = 9  # run test_op after 9 steps

# create tf.data.Dataset
train_input, train_len = input_pipeline(100, 5, is_train=True)
test_input, test_len = input_pipeline(100, 5, is_train=False)


# draw graph
def f1(): return train_input
def f2(): return test_input


cond_pl = tf.placeholder(tf.string, shape=None, name='cond_pl')
input = tf.cond(tf.equal(cond_pl, 'train'), lambda: f1(), lambda: f2())  # I thou

with tf.name_scope("Conv1"):
    W = tf.get_variable("W", shape=[3, 3, 1, 1],
                         initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable("b", shape=[1], initializer=tf.contrib.layers.xavier_initializer())
    layer1 = tf.nn.conv2d(input['img'], W, strides=[1, 1, 1, 1], padding='SAME') + b
    logits = tf.nn.relu(layer1)

loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=input['label'], predictions=logits))
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
test_op = print(loss)
#

# session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for ep in range(5):
        print('ep:{}'.format(ep))
        sess.run(input['iterator_init_op'], feed_dict={cond_pl: 'train'})
        sess.run(input['iterator_init_op'], feed_dict={cond_pl: 'test'})
        for step in range(ep_len):
            print('step:{}\r'.format(step))
            try:
                sess.run([train_op], feed_dict={cond_pl: 'train'})
                if step % step_len == (step_len - 1):
                    sess.run([test_op], feed_dict={cond_pl: 'test'})
            except tf.errors.OutOfRangeError:
                raise('drop the remainder')

回答1:


Consider example:

train = np.arange(90)
test = np.arange(10)

train_ds = tf.data.Dataset.from_tensor_slices(train).shuffle(10).batch(10).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).shuffle(10).batch(10).repeat()

train_iterator = train_ds.make_initializable_iterator()
test_iterator = test_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(train_iterator.initializer)
    sess.run(test_iterator.initializer)
    for i in range(len(train) + 1):
        print(sess.run(train_iterator.get_next()))
        if i % 9 == 8:
            print(sess.run(test_iterator.get_next()))

Two datasets, two iterators, both initialized at startup. When i exceeds length of datasets, it starts repeating both of them because of repeat(). If it'll be called with num_epochs or not called at all, you'll get end of sequence. If for some reason you need/want to use cond, maybe this answer will help

How to use Tensorflow's tf.cond() with two different Dataset iterators without iterating both?



来源:https://stackoverflow.com/questions/55516484/could-two-tf-data-dataset-coexist-and-controled-by-tf-cond

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!