Scheduled sampling in TensorFlow

长情又很酷 2020-12-31 19:13

The newest TensorFlow API for seq2seq models includes scheduled sampling:

https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTraini

3 answers
  • 2020-12-31 19:49

    This might also help you. This is for the case where you want to do scheduled sampling at each decoding step separately.

    import numpy as np
    import tensorflow as tf
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import gen_array_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops.distributions import bernoulli
    from tensorflow.python.ops.distributions import categorical

    batch_size = 64
    vocab_size = 50000
    emb_dim = 128

    # Stand-ins for one decoding step: the decoder logits, the ground-truth
    # next inputs (already embedded), and the embedding matrix.
    output = tf.get_variable(
        'output',
        initializer=tf.constant(np.random.rand(batch_size, vocab_size)))
    base_next_inputs = tf.get_variable(
        'input',
        initializer=tf.constant(np.random.rand(batch_size, emb_dim)))
    embedding = tf.get_variable(
        'embedding',
        initializer=tf.constant(np.random.rand(vocab_size, emb_dim)))

    # For each batch row, decide whether to sample (True with probability 0.99).
    select_sampler = bernoulli.Bernoulli(probs=0.99, dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=batch_size, seed=123)

    # Sample a token id from the logits; rows not selected get the sentinel -1.
    sample_id_sampler = categorical.Categorical(logits=output)
    sample_ids = array_ops.where(
        select_sample,
        sample_id_sampler.sample(seed=123),
        gen_array_ops.fill([batch_size], -1))

    # Row indices of the sampled and the non-sampled batch entries.
    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    # Gather the sampled ids and the ground-truth inputs for each group.
    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(
        base_next_inputs, where_not_sampling)
    sampled_next_inputs = tf.nn.embedding_lookup(
        embedding, sample_ids_sampling)

    # Scatter both groups back to their rows. The index sets are disjoint,
    # so the sum is the merged batch of next inputs.
    base_shape = array_ops.shape(base_next_inputs)
    result1 = array_ops.scatter_nd(
        indices=where_sampling, updates=sampled_next_inputs, shape=base_shape)
    result2 = array_ops.scatter_nd(
        indices=where_not_sampling, updates=inputs_not_sampling,
        shape=base_shape)
    result = result1 + result2
    

    I based this example on the TensorFlow source code for the seq2seq helpers: https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/contrib/seq2seq/python/ops/helper.py
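
    To make the where / gather_nd / scatter_nd merge in the snippet easier to follow, here is a minimal framework-free sketch of the same idea on plain Python lists. The function name merge_sampled_and_base and the toy values are made up for illustration; only the -1 sentinel and the "two scatters, then add" pattern come from the code above.

```python
def merge_sampled_and_base(sample_ids, base_inputs, embedding):
    """Mimic the where / gather_nd / scatter_nd merge for one decoding step.

    sample_ids[i] == -1 is the sentinel for "row i was not selected for sampling".
    """
    where_sampling = [i for i, s in enumerate(sample_ids) if s > -1]
    where_not_sampling = [i for i, s in enumerate(sample_ids) if s <= -1]

    dim = len(base_inputs[0])
    # Like scatter_nd: start from zeros and write the updates at the given rows.
    result1 = [[0.0] * dim for _ in base_inputs]
    for i in where_sampling:
        result1[i] = list(embedding[sample_ids[i]])
    result2 = [[0.0] * dim for _ in base_inputs]
    for i in where_not_sampling:
        result2[i] = list(base_inputs[i])

    # The two index sets are disjoint, so the elementwise sum is the merged batch.
    return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(result1, result2)]


embedding = [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]  # toy vocab of 3, emb_dim 2
base_inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]      # ground-truth next inputs
sample_ids = [2, -1, 0]                                  # row 1 was not sampled
merged = merge_sampled_and_base(sample_ids, base_inputs, embedding)
print(merged)
```

    Rows 0 and 2 take the embedding of their sampled id, while row 1 keeps its ground-truth input.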

  • 2020-12-31 19:55

    I contacted the engineer behind this, and he responded:

    The output sampler either emits the raw rnn output or the raw ground truth at that time step. The embedding sampler treats the rnn output as logits of a distribution and either emits the embedding lookup of a sampled id from that categorical distribution or the raw ground truth at that time step.
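
    The difference between the two samplers described above can be sketched for a single sequence and a single time step in plain Python. This is an illustration of the quoted description, not the library's implementation: the function names are hypothetical, and the real helpers make this choice independently for every row of a batch.

```python
import math
import random


def sample_categorical(logits, rng):
    """Draw one index from softmax(logits)."""
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def output_style_next_input(rnn_output, ground_truth, sampling_prob, rng):
    """Feed back the raw RNN output vector, or the raw ground truth."""
    if rng.random() < sampling_prob:
        return list(rnn_output)
    return list(ground_truth)


def embedding_style_next_input(rnn_output, ground_truth, embedding,
                               sampling_prob, rng):
    """Treat the RNN output as logits, sample an id and embed it,
    or fall back to the raw ground truth."""
    if rng.random() < sampling_prob:
        return list(embedding[sample_categorical(rnn_output, rng)])
    return list(ground_truth)


rng = random.Random(0)
embedding = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # toy: 3 tokens, 2 dimensions
logits = [2.0, 0.5, -1.0]                         # RNN output over the 3 tokens
truth_embedded = embedding[1]                     # ground-truth token 1, embedded

# Embedding style: the output has vocab size, the next input has embedding size.
print(embedding_style_next_input(logits, truth_embedded, embedding, 0.0, rng))
print(embedding_style_next_input(logits, truth_embedded, embedding, 1.0, rng))

# Output style: the output is fed back as-is, so it must have the input's size.
print(output_style_next_input([0.7, -0.2], [0.3, 0.4], 1.0, rng))
```

    Note that in the output style the RNN output and the ground truth are interchangeable as inputs, so they need the same dimensionality, whereas the embedding style goes through a token id in between.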

  • 2020-12-31 19:55

    Here's a basic example of using ScheduledEmbeddingTrainingHelper with TensorFlow 1.3 and some higher-level tf.contrib APIs. It's a sequence-to-sequence model in which the decoder's initial hidden state is the encoder's final hidden state. It only shows how to train on a single batch (the task is apparently "reverse this sequence"). For actual training tasks, I suggest looking at tf.contrib.learn APIs such as learn_runner and Experiment, and at tf.estimator.Estimator.

    import tensorflow as tf
    import numpy as np
    from tensorflow.python.layers.core import Dense
    
    vocab_size = 7
    embedding_size = 5
    lstm_units = 10
    
    src_batch = np.array([[1, 2, 3], [4, 5, 6]])
    trg_batch = np.array([[3, 2, 1], [6, 5, 4]])
    
    # *_seq will have shape (2, 3), *_seq_len will have shape (2)
    source_seq = tf.placeholder(shape=(None, None), dtype=tf.int32)
    target_seq = tf.placeholder(shape=(None, None), dtype=tf.int32)
    source_seq_len = tf.placeholder(shape=(None,), dtype=tf.int32)
    target_seq_len = tf.placeholder(shape=(None,), dtype=tf.int32)
    
    # add Start of Sequence (SOS) tokens to each sequence
    batch_size, sequence_size = tf.unstack(tf.shape(target_seq))
    sos_slice = tf.zeros([batch_size, 1], dtype=tf.int32) # 0 = start of sentence token
    decoder_input = tf.concat([sos_slice, target_seq], axis=1)
    
    embedding_matrix = tf.get_variable(
        name="embedding_matrix",
        shape=[vocab_size, embedding_size],
        dtype=tf.float32)
    source_seq_embedded = tf.nn.embedding_lookup(embedding_matrix, source_seq) # shape=(2, 3, 5)
    decoder_input_embedded = tf.nn.embedding_lookup(embedding_matrix, decoder_input) # shape=(2, 4, 5)
    
    unused_encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        tf.contrib.rnn.LSTMCell(lstm_units),
        source_seq_embedded,
        sequence_length=source_seq_len,
        dtype=tf.float32)
    
    # Decoder:
    # At each time step t and for each sequence in the batch, we get x_t by either
    #   (1) sampling from the distribution output_layer(t-1), or
    #   (2) reading from decoder_input_embedded.
    # We do (1) with probability sampling_probability and (2) with 1 - sampling_probability.
    # Using sampling_probability=0.0 is equivalent to using TrainingHelper (no sampling).
    # Using sampling_probability=1.0 resembles inference by sampling, where we
    # don't supervise the decoder at all: the id sampled from the output at t-1
    # is embedded and used as the input at t.
    sampling_prob = tf.Variable(0.0, dtype=tf.float32)
    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
        decoder_input_embedded,
        target_seq_len,
        embedding_matrix,
        sampling_probability=sampling_prob)
    
    output_layer = Dense(vocab_size)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        tf.contrib.rnn.LSTMCell(lstm_units),
        helper,
        encoder_state,
        output_layer=output_layer)
    
    outputs, state, seq_len = tf.contrib.seq2seq.dynamic_decode(decoder)
    loss = tf.contrib.seq2seq.sequence_loss(
        logits=outputs.rnn_output,
        targets=target_seq,
        weights=tf.ones(trg_batch.shape))
    
    train_op = tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=tf.contrib.framework.get_global_step(),
        optimizer=tf.train.AdamOptimizer,
        learning_rate=0.001)
    
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        _, _loss = session.run([train_op, loss], {
            source_seq: src_batch,
            target_seq: trg_batch,
            source_seq_len: [3, 3],
            target_seq_len: [3, 3],
            sampling_prob: 0.5
        })
        print("Loss: " + str(_loss))
    

    For ScheduledOutputTrainingHelper, I would expect to just swap out the helper and use:

    helper = tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
        target_seq,
        target_seq_len,
        sampling_probability=sampling_prob)
    

    However, this gives an error, because target_seq holds integer ids of shape (batch_size, sequence_size), while the LSTM cell expects a multidimensional input per time step (of shape (batch_size, input_dims)). I will raise an issue on GitHub to find out whether this is a bug or whether there is some other way to use ScheduledOutputTrainingHelper.
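
    The training example above feeds a fixed sampling_prob of 0.5 for its single batch. Over a longer training run one would typically raise the probability gradually, which is the "scheduled" part of scheduled sampling. The sketch below shows two possible schedules; the function names, ramp_steps and k are assumptions of mine and not part of the TensorFlow API, and the step values are arbitrary.

```python
import math


def linear_sampling_prob(step, ramp_steps, max_prob=1.0):
    """Rise linearly from 0 to max_prob over ramp_steps, then stay flat."""
    return max_prob * min(1.0, step / float(ramp_steps))


def inverse_sigmoid_sampling_prob(step, k=1000.0):
    """1 - k / (k + exp(step / k)): a slow start, then a smooth rise towards 1."""
    exponent = min(step / k, 700.0)  # cap to keep math.exp from overflowing
    return 1.0 - k / (k + math.exp(exponent))


# Print both schedules at a few (arbitrary) training steps.
for step in (0, 2500, 5000, 10000):
    print(step,
          linear_sampling_prob(step, ramp_steps=10000),
          round(inverse_sigmoid_sampling_prob(step), 4))
```

    The value computed for the current step could then be fed through the feed dict in place of the constant, for example sampling_prob: linear_sampling_prob(step, ramp_steps=10000).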
