Can we use tf.spectral Fourier functions in Keras?

清酒与你 2021-02-06 16:56

Let us start with an input that is a simple time series and try to build an autoencoder in Keras that simply Fourier-transforms our data and then inverse-transforms it.

If we try to

4 Answers
  •  逝去的感伤
    2021-02-06 17:26

    I think you just need more Lambda wrapping, i.e. each tf.spectral call goes inside its own Lambda layer (I'm using tf.keras, since that's what I have installed):

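    import tensorflow as tf  # TensorFlow 1.x: tf.spectral and tf.Session were removed in 2.0
    K = tf.keras
    
    inputs = K.Input(shape=(10, 8), name='main_input')
    # rfft over the last axis: 8 real samples -> 5 complex frequency bins
    x = K.layers.Lambda(tf.spectral.rfft)(inputs)
    # irfft maps the 5 bins back to 8 real samples
    decoded = K.layers.Lambda(tf.spectral.irfft)(x)
    model = K.Model(inputs, decoded)
    output = model(tf.ones([10, 8]))
    with tf.Session():  # graph mode: the symbolic output is evaluated in a session
      print(output.eval())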
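
    The code above targets TensorFlow 1.x: in TensorFlow 2.0 the tf.spectral module was removed (its FFT functions live under tf.signal) and tf.Session is only available as tf.compat.v1.Session. The following is a minimal sketch of the same round trip for TF 2.x, assuming tf.keras and tf.signal.rfft/irfft. As a design choice not taken from the answer, it stores the real and imaginary parts of the spectrum side by side as float32 features instead of passing a complex tensor between layers, so no information is dropped. The names FRAME_LEN, to_spectrum and from_spectrum are hypothetical.

```python
import numpy as np
import tensorflow as tf

FRAME_LEN = 8  # hypothetical: length of the last axis; rfft yields FRAME_LEN // 2 + 1 bins


def to_spectrum(v):
    # Real FFT over the last axis, then concatenate real and imaginary parts
    # so the layer output is an ordinary float32 tensor.
    spec = tf.signal.rfft(v)
    return tf.concat([tf.math.real(spec), tf.math.imag(spec)], axis=-1)


def from_spectrum(v):
    # Split the concatenated halves back into a complex spectrum and invert it.
    re, im = tf.split(v, 2, axis=-1)
    return tf.signal.irfft(tf.complex(re, im), fft_length=[FRAME_LEN])


inputs = tf.keras.Input(shape=(10, FRAME_LEN), name="main_input")
x = tf.keras.layers.Lambda(to_spectrum)(inputs)      # shape (None, 10, 10)
decoded = tf.keras.layers.Lambda(from_spectrum)(x)   # shape (None, 10, 8)
model = tf.keras.Model(inputs, decoded)

# Batch of one 10 x 8 ramp (0 ... 79); eager execution, so no Session is needed.
batch = np.arange(80, dtype=np.float32).reshape(1, 10, FRAME_LEN)
output = model(batch).numpy()
print(np.round(output[0], 4))
```

    The intermediate output has 5 real parts followed by 5 imaginary parts per row, which an encoder made of ordinary Dense or Conv1D layers could consume directly.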
    

    irfft already returns a real-valued tensor, so there is probably no need to cast its output. But if you do need to cast it (or, more generally, to combine several operations in one Lambda layer), I'd wrap them in a Python lambda: K.layers.Lambda(lambda v: tf.cast(tf.spectral.whatever(v), tf.float32))

    For example, if you know that your intermediate values (between rfft and irfft) will have an imaginary component of zero, you can truncate it off:

    import tensorflow as tf  # TensorFlow 1.x
    K = tf.keras
    
    inputs = K.Input(shape=(10, 8), name='main_input')
    # Keep only the real part of the spectrum (float32, 5 bins per row)
    x = K.layers.Lambda(lambda v: tf.real(tf.spectral.rfft(v)))(inputs)
    # Rebuild a complex spectrum with a zero imaginary part, then invert it
    decoded = K.layers.Lambda(
        lambda v: tf.spectral.irfft(tf.complex(real=v, imag=tf.zeros_like(v))))(x)
    model = K.Model(inputs, decoded)
    output = model(tf.reshape(tf.range(80, dtype=tf.float32), [10, 8]))
    with tf.Session():
      print(output.eval())
    

    Note that this isn't true for general sequences: even real-valued inputs usually have nonzero imaginary components once transformed. It works for the tf.ones input above (whose spectrum is purely real), but the tf.range input gets mangled:

    [[ 0.  4.  4.  4.  4.  4.  4.  4.]
     [ 8. 12. 12. 12. 12. 12. 12. 12.]
     [16. 20. 20. 20. 20. 20. 20. 20.]
     [24. 28. 28. 28. 28. 28. 28. 28.]
     [32. 36. 36. 36. 36. 36. 36. 36.]
     [40. 44. 44. 44. 44. 44. 44. 44.]
     [48. 52. 52. 52. 52. 52. 52. 52.]
     [56. 60. 60. 60. 60. 60. 60. 60.]
     [64. 68. 68. 68. 68. 68. 68. 68.]
     [72. 76. 76. 76. 76. 76. 76. 76.]]
    

    (Without discarding the imaginary part, as in the first example, we get 0. through 79. reconstructed perfectly.)
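
    To see why the truncated version fails, the same computation can be reproduced outside Keras with NumPy's np.fft.rfft and np.fft.irfft (a sketch assuming NumPy is available; it is not part of the original answer). For a ramp, bins 1 to 3 of the spectrum have large imaginary parts, so dropping them changes the signal, while keeping them gives an exact round trip.

```python
import numpy as np

# Same data as tf.reshape(tf.range(80), [10, 8]): each row is a ramp a, a+1, ..., a+7.
x = np.arange(80, dtype=np.float64).reshape(10, 8)

spec = np.fft.rfft(x, axis=-1)  # shape (10, 5), complex128

# Keeping the full complex spectrum reconstructs the input (up to rounding).
exact = np.fft.irfft(spec, n=8, axis=-1)

# Dropping the imaginary parts, as the tf.real(...) Lambda does, changes the signal.
truncated = np.fft.irfft(spec.real, n=8, axis=-1)

print(np.round(spec[0].imag, 3))   # bins 1-3 have large nonzero imaginary parts
print(np.round(exact[0], 6))       # first row of the input: 0 ... 7
print(np.round(truncated[0], 6))   # 0 followed by seven 4s, as in the mangled output
```

    The imaginary parts are identical in every row, because adding the constant offset a only changes bin 0; that is why each mangled row is a followed by a + 4 repeated.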
