TensorFlow: getting all states from a RNN

Backend · Unresolved · 2 answers · 1464 views
说谎 2021-02-07 23:07

How do you get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.

The first al

2 answers
  •  情书的邮戳
    2021-02-07 23:35

    I've already created a PR here; it may help you with simple cases.

    Let me briefly explain my implementation so that you can write your own version if you need to. The main change is a modification of the _time_step function:

    def _time_step(time, output_ta_t, state, *args):
        ...  # original body unchanged, plus the args handling shown below
    

    The parameters are the same as before, except that an extra *args is passed in. Why *args? Because I want to keep TensorFlow's usual behavior: you can still get only the final state by simply ignoring args (leaving it empty):

    if states_ta is not None:
        # If you want to return all states, set `args` to be `states_ta`
        loop_vars = (time, output_ta, state, states_ta)
    else:
        # If you want the final state only, ignore `args`
        loop_vars = (time, output_ta, state)
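
    A minimal pure-Python sketch of the idea (not TensorFlow code) showing how one body with a trailing *args serves both loop_vars layouts. The while_loop function, the toy cell h_t = 0.5 * h_{t-1} + x_t, and the plain lists standing in for TensorArrays are hypothetical stand-ins chosen for illustration; only the shape of _time_step mirrors the snippets in this answer, including the args write described next.

```python
# Pure-Python illustration of the *args trick; this is NOT TensorFlow code.
# Hypothetical stand-ins: `while_loop` mimics tf.while_loop's contract
# (call body until cond is False, return the final loop variables), and
# plain Python lists play the role of TensorArrays.

def while_loop(cond, body, loop_vars):
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def make_time_step(inputs):
    def _time_step(time, output_ta_t, state, *args):
        new_state = 0.5 * state + inputs[time]   # toy cell: h_t = 0.5*h_{t-1} + x_t
        output_ta_t = output_ta_t + [new_state]  # for this toy cell, output == state
        if args:
            # args == (states_ta,), where states_ta is a tuple of accumulators
            args = (tuple(ta + [out] for ta, out in zip(args[0], [new_state])),)
        # Return the same structure the loop variables came in with.
        return (time + 1, output_ta_t, new_state) + args
    return _time_step


inputs = [1.0, 2.0, 3.0]
cond = lambda time, *_: time < len(inputs)
body = make_time_step(inputs)

# Final state only: three loop variables, so args stays empty.
print(while_loop(cond, body, (0, [], 0.0)))
# All states: a fourth loop variable (states_ta) is threaded through as args.
print(while_loop(cond, body, (0, [], 0.0, ([],))))
```

    Since tf.while_loop requires the body to return a structure matching loop_vars, the body simply passes args through with the same length it received: empty in the first call, one extra element in the second.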
    

    How is args used? Inside _time_step, the new state is written to the state TensorArray at every step:

    if args:
        args = tuple(
            ta.write(time, out) for ta, out in zip(args[0], [new_state])
        )
    

    This is just an adaptation of the following original code, which writes the outputs:

    output_ta_t = tuple(
        ta.write(time, out) for ta, out in zip(output_ta_t, output)
    )
    

    Now args should contain all the states you want.
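
    Both snippets rebuild the tuple from the return values of write() rather than writing in place. That matters because TensorArray.write returns a new TensorArray object that should be used for all subsequent operations. The FakeTensorArray class below is a hypothetical pure-Python stand-in that mimics only that property; the (output, new_state) values are made up.

```python
# Hypothetical stand-in for tf.TensorArray, NOT TensorFlow's API.
# It mimics one property only: write() returns a new array and leaves
# the original untouched, so callers must keep the returned object.

class FakeTensorArray:
    def __init__(self, items=()):
        self._items = tuple(items)

    def write(self, index, value):
        # This toy only supports writing to the next free index.
        assert index == len(self._items), "toy array only supports appending"
        return FakeTensorArray(self._items + (value,))

    def stack(self):
        return list(self._items)


output_ta_t = (FakeTensorArray(),)  # one array per flattened output
states_ta = (FakeTensorArray(),)    # one array for the (single) state tensor

# Made-up (output, new_state) pairs for two time steps.
for time, (output, new_state) in enumerate([(1.0, 1.0), (2.5, 2.5)]):
    # Same pattern as the snippets above: rebuild each tuple from write()'s results.
    output_ta_t = tuple(ta.write(time, out) for ta, out in zip(output_ta_t, [output]))
    states_ta = tuple(ta.write(time, out) for ta, out in zip(states_ta, [new_state]))

print(states_ta[0].stack())  # [1.0, 2.5]
```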

    With that in place, you can retrieve all the states (or just the final state) with the following code:

    _, output_final_ta, *state_info = control_flow_ops.while_loop( ...
    

    and

    if states_ta is not None:
        final_state, states_final_ta = state_info
    else:
        final_state, states_final_ta = state_info[0], None
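
    The *state_info in the unpacking absorbs one or two trailing loop values, depending on whether states_ta was passed in, which is why the same line works in both modes. A pure-Python sketch with made-up result tuples; the split_results helper and the string placeholders are hypothetical:

```python
# Pure-Python illustration of the unpacking step; no TensorFlow involved.
# The tuples passed in below are made-up stand-ins for what the while loop
# returns: (time, outputs, final_state[, all_states]).

def split_results(loop_result, return_all_states):
    # Extended unpacking: everything after the outputs lands in state_info.
    _, output_final_ta, *state_info = loop_result
    if return_all_states:
        final_state, states_final_ta = state_info
    else:
        final_state, states_final_ta = state_info[0], None
    return output_final_ta, final_state, states_final_ta


print(split_results((3, "outputs", "h3"), return_all_states=False))
# ('outputs', 'h3', None)
print(split_results((3, "outputs", "h3", "h1..h3"), return_all_states=True))
# ('outputs', 'h3', 'h1..h3')
```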
    

    I haven't tested it in complicated cases, but it should work in 'simple' ones (here are my test cases).
