Let us start with an input that is a simple time series and try to build an autoencoder in Keras that simply Fourier transforms and then inverse transforms our data. How can the FFT and inverse FFT be wrapped as layers so the round trip is lossless?
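As a quick sanity check, the transform pair itself is lossless outside of Keras (a minimal NumPy sketch; the array values are just an arbitrary example), so the difficulty is only in wiring it into layers:

import numpy as np

series = np.arange(8, dtype=np.float32)
roundtrip = np.fft.irfft(np.fft.rfft(series), n=len(series))
print(np.allclose(series, roundtrip))  # True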
I stumbled upon this as I was trying to solve the same problem. You can make the transition lossless by wrapping tf.real and tf.imag in Lambda layers (I'm using stft because there's no real-valued equivalent):
# Forward STFT wrapped in a Lambda layer.
x = tf.keras.layers.Lambda(
    lambda v: tf.signal.stft(
        v,
        frame_length=1024,
        frame_step=256,
        fft_length=1024,
    ), name='gen/FFTLayer')(inputs)

# Split the complex STFT into its real and imaginary parts.
real = tf.keras.layers.Lambda(tf.real)(x)
imag = tf.keras.layers.Lambda(tf.imag)(x)

...
# transform real and imag either separately or by concatenating them in the
# feature space (one option is sketched after this block).
...

# Recombine into a complex tensor and invert the STFT.
x = tf.keras.layers.Lambda(lambda x: tf.complex(x[0], x[1]))([real, imag])
x = tf.keras.layers.Lambda(
    lambda v: tf.signal.inverse_stft(
        v,
        frame_length=1024,
        frame_step=256,
        fft_length=1024,
    ))(x)
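For the elided middle section, one possibility is the "concatenate in the feature space" option mentioned in the comment. This is only my own sketch, and the Conv1D size is an arbitrary assumption:

# With frame_length=1024 the STFT has 513 frequency bins, so real and imag
# are each (batch, frames, 513). Stack them, process the stacked tensor,
# then split back into two halves before recombining into a complex tensor.
concat = tf.keras.layers.Concatenate(axis=-1)([real, imag])   # (batch, frames, 1026)
hidden = tf.keras.layers.Conv1D(filters=1026, kernel_size=3,
                                padding='same', activation='relu')(concat)
real = tf.keras.layers.Lambda(lambda t: t[..., :513])(hidden)
imag = tf.keras.layers.Lambda(lambda t: t[..., 513:])(hidden)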
The fft2d function in TensorFlow 1.13.1 is broken.
I think you just need more Lambda wrapping (using tf.keras since that's what I have installed):
import numpy
import tensorflow as tf

K = tf.keras

inputs = K.Input(shape=(10, 8), name='main_input')
# Wrap the raw TF ops in Lambda layers so Keras can treat them as layers.
x = K.layers.Lambda(tf.spectral.rfft)(inputs)
decoded = K.layers.Lambda(tf.spectral.irfft)(x)
model = K.Model(inputs, decoded)

output = model(tf.ones([10, 8]))
with tf.Session():
    print(output.eval())
The output of irfft should be real, so probably no need to cast it. But if you do need to cast it (or in general combine operations in a Lambda layer), I'd wrap that in a Python lambda: K.layers.Lambda(lambda v: tf.cast(tf.spectral.whatever(v), tf.float32))

For example, if you know your intermediate values (between rfft and irfft) will have an imaginary component of zero, you can truncate that off:
import numpy
import tensorflow as tf

K = tf.keras

inputs = K.Input(shape=(10, 8), name='main_input')
# Keep only the real part of the rfft output...
x = K.layers.Lambda(lambda v: tf.real(tf.spectral.rfft(v)))(inputs)
# ...and rebuild a complex tensor with a zero imaginary part before irfft.
decoded = K.layers.Lambda(
    lambda v: tf.spectral.irfft(tf.complex(real=v, imag=tf.zeros_like(v))))(x)
model = K.Model(inputs, decoded)

output = model(tf.reshape(tf.range(80, dtype=tf.float32), [10, 8]))
with tf.Session():
    print(output.eval())
Note that this isn't true for general sequences, since even real-valued inputs can have imaginary components once transformed. It works for the tf.ones input above, but the tf.range input gets mangled:
[[ 0. 4. 4. 4. 4. 4. 4. 4.]
[ 8. 12. 12. 12. 12. 12. 12. 12.]
[16. 20. 20. 20. 20. 20. 20. 20.]
[24. 28. 28. 28. 28. 28. 28. 28.]
[32. 36. 36. 36. 36. 36. 36. 36.]
[40. 44. 44. 44. 44. 44. 44. 44.]
[48. 52. 52. 52. 52. 52. 52. 52.]
[56. 60. 60. 60. 60. 60. 60. 60.]
[64. 68. 68. 68. 68. 68. 68. 68.]
[72. 76. 76. 76. 76. 76. 76. 76.]]
(Without the casting we get 0. through 79. reconstructed perfectly)
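The reason is visible in the spectra themselves: a constant row has all of its energy in the purely real DC bin, while a ramp has nonzero imaginary parts that get thrown away. A quick NumPy illustration (printed values are approximate):

import numpy as np

# Constant input: only the DC bin is nonzero, and it is purely real,
# so dropping the imaginary part loses nothing.
print(np.fft.rfft(np.ones(8)))
# [8.+0.j  0.+0.j  0.+0.j  0.+0.j  0.+0.j]

# Ramp input: the imaginary parts are nonzero, so truncating them changes
# what irfft reconstructs.
print(np.fft.rfft(np.arange(8, dtype=float)))
# [28.+0.j  -4.+9.657j  -4.+4.j  -4.+1.657j  -4.+0.j]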
Just to add more to what's going on above for anyone who gets here from a search engine: the following, contributed in this Google Groups discussion, will run rfft and then irfft with convolutions and other layers in between:
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Conv1D, Flatten, Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(10, 8), name='main_input')
# Cast the complex rfft output to float (this keeps only the real part).
x = Lambda(lambda v: tf.to_float(tf.spectral.rfft(v)))(inputs)
x = Conv1D(filters=5, kernel_size=3, activation='relu', padding='same')(x)
# Cast back to complex before the inverse transform.
x = Lambda(lambda v: tf.to_float(tf.spectral.irfft(tf.cast(v, dtype=tf.complex64))))(x)
x = Flatten()(x)
output = Dense(1)(x)
model = Model(inputs, output)
model.summary()
It uses the same concepts as Allen's answer but the slight differences allow compatibility with intermediate convolutions.
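For completeness, the model above can be compiled and fit like any other Keras model. Here is a quick smoke test on random data (the shapes follow the (10, 8) input and the Dense(1) output; everything else is an arbitrary choice of mine):

import numpy as np

x_train = np.random.rand(32, 10, 8).astype('float32')  # 32 dummy samples
y_train = np.random.rand(32, 1).astype('float32')       # scalar targets for Dense(1)

model.compile(optimizer='adam', loss='mse')
model.fit(x_train, y_train, epochs=2, batch_size=8)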