Invertible STFT and ISTFT in Python

自闭症患者 2020-12-02 07:57

Is there any general-purpose form of short-time Fourier transform with corresponding inverse transform built into SciPy or NumPy or whatever?

There's the pyplot

  • 2020-12-02 08:34

    librosa.core.stft and istft look pretty similar to what I was looking for, though they didn't exist at the time:

    librosa.core.stft(y, n_fft=2048, hop_length=None, win_length=None, window=None, center=True, dtype=<type 'numpy.complex64'>)

    They don't invert exactly, though; the ends are tapered.

  • 2020-12-02 08:36

    Here is the STFT code that I use. STFT + ISTFT here gives perfect reconstruction (even for the first frames). I slightly modified the code given here by Steve Tjoa; in my version, the magnitude of the reconstructed signal is the same as that of the input signal.

    import numpy as np
    
    def stft(x, fftsize=1024, overlap=4):
        hop = fftsize // overlap                  # integer hop size (needed by range() in Python 3)
        w = np.hanning(fftsize+1)[:-1]            # better reconstruction with this trick +1)[:-1]
        return np.array([np.fft.rfft(w*x[i:i+fftsize])
                         for i in range(0, len(x)-fftsize+1, hop)])
    
    def istft(X, overlap=4):
        fftsize = (X.shape[1]-1)*2                # assumes an even fftsize was used
        hop = fftsize // overlap
        w = np.hanning(fftsize+1)[:-1]
        x = np.zeros((X.shape[0]-1)*hop + fftsize)  # room for every frame
        wsum = np.zeros_like(x)
        for n, i in enumerate(range(0, len(x)-fftsize+1, hop)):
            x[i:i+fftsize] += np.fft.irfft(X[n], fftsize) * w   # overlap-add
            wsum[i:i+fftsize] += w ** 2.
        pos = wsum != 0
        x[pos] /= wsum[pos]                       # divide out the summed squared windows
        return x
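    The `+1)[:-1]` trick builds a periodic Hann window, and `istft` divides by `wsum`, the running sum of squared windows. The sketch below, assuming the default `fftsize=1024` and `overlap=4` (the helper `overlapped_window_sum` is hypothetical, written only for this check), computes that sum in the interior of the signal, where every sample is covered by four frames. For the periodic window it is flat; for the plain symmetric `np.hanning(fftsize)` it is not exactly the same constant.

```python
import numpy as np

def overlapped_window_sum(w, hop, power=2):
    """Hypothetical helper: steady-state sum of w**power over all frames
    that overlap one sample, folded into a single hop-length period."""
    wp = w ** power
    total = np.zeros(hop)
    for start in range(0, len(wp), hop):
        seg = wp[start:start + hop]
        total[:len(seg)] += seg
    return total

fftsize, overlap = 1024, 4
hop = fftsize // overlap

periodic = np.hanning(fftsize + 1)[:-1]   # the "+1)[:-1]" trick from the answer
symmetric = np.hanning(fftsize)           # plain symmetric Hann, for comparison

p = overlapped_window_sum(periodic, hop)
s = overlapped_window_sum(symmetric, hop)
print("periodic :", p.min(), p.max())
print("symmetric:", s.min(), s.max())
```

    With four overlapping frames, the squared periodic Hann window sums to 4 × 3/8 = 1.5 at every interior position, so dividing by `wsum` there is just a constant rescaling. Near the ends fewer frames overlap, which is why the answer divides pointwise by `wsum` rather than by a single constant.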
    
  • 2020-12-02 08:39

    Here is my Python code, simplified for this answer:

    import numpy as np
    import matplotlib.pyplot as plt
    
    def stft(x, fs, framesz, hop):
        framesamp = int(framesz*fs)
        hopsamp = int(hop*fs)
        w = np.hanning(framesamp+1)[:-1]   # periodic Hann: copies at 50% overlap sum exactly to 1
        # Block processing: window and FFT each frame, then stack the rows into a 2D array.
        X = np.array([np.fft.fft(w*x[i:i+framesamp])
                      for i in range(0, len(x)-framesamp+1, hopsamp)])
        return X
    
    def istft(X, fs, T, hop):
        x = np.zeros(int(T*fs))
        framesamp = X.shape[1]
        hopsamp = int(hop*fs)
        for n, i in enumerate(range(0, len(x)-framesamp+1, hopsamp)):
            x[i:i+framesamp] += np.real(np.fft.ifft(X[n]))   # naive overlap-add
        return x
    

    Notes:

    1. The list comprehension is a little trick I like to use to simulate block processing of signals in NumPy/SciPy. It's like blkproc in Matlab. Instead of a for loop, I apply a command (e.g., fft) to each frame of the signal inside a list comprehension, and then the array constructor casts the list to a 2D array. I use this to make spectrograms, chromagrams, MFCC-grams, and much more.
    2. For this example, I use a naive overlap-and-add method in istft. To reconstruct the original signal, the sum of the sequential window functions must be constant, preferably equal to unity (1.0). In this case, I've chosen the Hann (or hanning) window and a 50% overlap, which works perfectly. See this discussion for more information.
    3. There are probably more principled ways of computing the ISTFT. This example is mainly meant to be educational.

    A test:

    if __name__ == '__main__':
        f0 = 440         # Compute the STFT of a 440 Hz sinusoid
        fs = 8000        # sampled at 8 kHz
        T = 5            # lasting 5 seconds
        framesz = 0.050  # with a frame size of 50 milliseconds
        hop = 0.025      # and hop size of 25 milliseconds.
    
        # Create test signal and STFT.
        t = np.linspace(0, T, T*fs, endpoint=False)
        x = np.sin(2*np.pi*f0*t)
        X = stft(x, fs, framesz, hop)
    
        # Plot the magnitude spectrogram.
        plt.figure()
        plt.imshow(np.abs(X.T), origin='lower', aspect='auto',
                   interpolation='nearest')
        plt.xlabel('Time')
        plt.ylabel('Frequency')
    
        # Compute the ISTFT.
        xhat = istft(X, fs, T, hop)
    
        # Plot the input and output signals over 0.1 seconds.
        T1 = int(0.1*fs)
    
        plt.figure()
        plt.plot(t[:T1], x[:T1], t[:T1], xhat[:T1])
        plt.xlabel('Time (seconds)')
    
        plt.figure()
        plt.plot(t[-T1:], x[-T1:], t[-T1:], xhat[-T1:])
        plt.xlabel('Time (seconds)')
    
        plt.show()   # display all three figures
    

    [Figures: STFT of 440 Hz sinusoid; ISTFT of beginning of 440 Hz sinusoid; ISTFT of end of 440 Hz sinusoid]
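    Notes 1 and 2 can be checked numerically. This sketch assumes the test's parameters (fs = 8000, 50 ms frames, 25 ms hop) and a random test signal. `sliding_window_view` requires NumPy 1.20 or later and is a swapped-in, loop-free alternative to the list comprehension, not part of the original answer. The check confirms that the strided view yields the same frames, and that two copies of a periodic Hann window offset by 50% add up to exactly 1, which is what the unnormalized overlap-add in `istft` relies on.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

fs, framesz, hop = 8000, 0.050, 0.025
framesamp, hopsamp = int(framesz * fs), int(hop * fs)   # 400 and 200 samples
x = np.random.default_rng(0).standard_normal(fs)        # 1 s of noise (illustrative only)

# Note 1: block processing with a list comprehension...
frames_lc = np.array([x[i:i + framesamp]
                      for i in range(0, len(x) - framesamp + 1, hopsamp)])
# ...and the same frames as a strided view, without a Python loop.
frames_sv = sliding_window_view(x, framesamp)[::hopsamp]
assert np.array_equal(frames_lc, frames_sv)

# Note 2: unnormalized overlap-add needs the shifted windows to sum to a constant.
w = np.hanning(framesamp + 1)[:-1]   # periodic Hann window
cola = w[:hopsamp] + w[hopsamp:]     # two 50%-overlapped copies, folded onto one hop
print(cola.min(), cola.max())
```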

  • 2020-12-02 08:42

    I'm a little late to this, but I realised SciPy has a built-in istft function (scipy.signal.istft, alongside scipy.signal.stft) as of version 0.19.0.
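    A minimal round-trip sketch with `scipy.signal.stft` and `scipy.signal.istft`. It assumes SciPy 0.19.0 or later and reuses the 440 Hz / 8 kHz / 5 s sinusoid from the earlier test, with 400-sample Hann frames at 50% overlap. These parameter choices are illustrative, not the only valid ones.

```python
import numpy as np
from scipy import signal

fs = 8000
t = np.arange(5 * fs) / fs
x = np.sin(2 * np.pi * 440 * t)   # 5 s of a 440 Hz sinusoid

# Forward transform: 400-sample Hann frames, 200-sample overlap (50%).
f, frame_times, Zxx = signal.stft(x, fs=fs, window='hann', nperseg=400, noverlap=200)

# Inverse transform with the same window and overlap.
_, xhat = signal.istft(Zxx, fs=fs, window='hann', nperseg=400, noverlap=200)
xhat = xhat[:len(x)]              # guard against any extra padded samples

print(np.max(np.abs(x - xhat)))   # maximum reconstruction error
```

    `stft` zero-pads the signal by default (`boundary='zeros'`, `padded=True`), so `istft` may return a signal a little longer than the input. Trimming to `len(x)` before comparing keeps the check honest in that case.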
