Friday 8 March 2024

Plotting a Spectrogram In Python, Using NumPy and Matplotlib

When performing frequency-domain (FFT-based) processing, it is often useful to display a spectrogram of the frequency-domain results. SciPy has a very good spectrogram function (scipy.signal.spectrogram), but it takes time-domain data and does all of the clever stuff (framing, windowing and the FFTs) itself. If you are already processing data in the frequency domain, you often just want to build the spectrogram dataset yourself and keep appending FFT results to it.
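If your processing chain already produces FFT frames, the spectrogram can be built up one frame at a time. Below is a minimal sketch of that pattern, assuming each incoming frame is the full complex FFT of a real-valued, windowed block. The SpectrogramBuilder class and its method names are hypothetical, not part of NumPy or SciPy, and the random frames stand in for real signal data.

```python
import numpy as np

class SpectrogramBuilder:
    """Hypothetical helper: collects FFT frames as they arrive from a
    frequency-domain processing loop and keeps the non-redundant magnitudes."""

    def __init__(self, fftLength):
        self.fftLength = fftLength
        self.halfFftLength = fftLength >> 1
        self.frames = []                        # Python list: cheap to append to

    def append(self, fftResult):
        # fftResult is assumed to be the full complex FFT of one real-valued frame
        magnitude = np.abs(fftResult[:self.halfFftLength])
        self.frames.append(magnitude)

    def asArray(self):
        # Rows are frames (time), columns are frequency bins
        return np.asarray(self.frames)

# Usage: feed in FFT results produced elsewhere in the processing chain
fftLength = 256
window = np.hanning(fftLength)
rng = np.random.default_rng(0)
builder = SpectrogramBuilder(fftLength)
for _ in range(10):
    frame = rng.standard_normal(fftLength) * window   # stand-in for a real signal block
    builder.append(np.fft.fft(frame))

spectrogram = builder.asArray()
print(spectrogram.shape)   # (10, 128)
```

Converting the list to an array only once, at plot time, avoids repeatedly reallocating a NumPy array on every append.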

A spectrogram is effectively a 3D plot drawn as a 2D image, with the following configuration:

  • Time is on the X axis
  • Frequency is on the Y axis
  • Frequency magnitude is shown in colour
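To see how the rotation used later puts time on the x-axis, here is a small sketch with a toy array (not real FFT data) in which each value encodes its own frame and bin as 10*frame + bin. Because imshow() draws row 0 at the top by default, after np.rot90 the highest frequency bin ends up at the top of the image.

```python
import numpy as np

# Toy data: 4 frames (time) x 3 frequency bins, value = 10*frame + bin
nFrames, nBins = 4, 3
frames = np.array([[10 * f + b for b in range(nBins)] for f in range(nFrames)])
print(frames.shape)      # (4, 3): rows are time, columns are frequency

rotated = np.rot90(frames)   # rotate 90 degrees counter-clockwise
print(rotated.shape)     # (3, 4): rows are frequency, columns are time
print(rotated[0])        # [ 2 12 22 32]  highest bin, frames 0..3 left to right
print(rotated[-1])       # [ 0 10 20 30]  lowest bin (DC) in the bottom row
```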

This program will use the Matplotlib function imshow() to display the spectrogram.

There are three key tricks to using imshow() for this purpose:

  • Rotate the dataset so that time is on the x-axis and frequency is on the y-axis
  • Scale the x- and y-axis labels to correctly show time and frequency
  • Remove the second half of the FFT results: for real-valued input, once the magnitude of the FFT result has been calculated, the two halves of the result are mirror images, so we can discard the upper half.
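Both the mirror-image property and the axis scaling can be checked numerically. This is a minimal sketch using random data in place of a real signal frame, with the same Fs and fftLength values as the program below. np.fft.rfft, which returns only the non-redundant bins of a real-input FFT, is shown purely as an alternative to slicing the full FFT; the program itself uses np.fft.fft.

```python
import numpy as np

Fs = 10000                 # Sampling frequency (Hz)
fftLength = 256
halfFftLength = fftLength >> 1

# A real-valued, windowed test frame (random data stands in for a signal block)
rng = np.random.default_rng(1)
frame = rng.standard_normal(fftLength) * np.hanning(fftLength)

magnitude = np.abs(np.fft.fft(frame))
# For real input, bin k and bin fftLength-k have equal magnitude (mirror image)
k = np.arange(1, halfFftLength)
print(np.allclose(magnitude[k], magnitude[fftLength - k]))   # True

# np.fft.rfft returns bins 0..fftLength/2 only, i.e. the non-redundant half
rfftMagnitude = np.abs(np.fft.rfft(frame))
print(np.allclose(rfftMagnitude[:halfFftLength], magnitude[:halfFftLength]))   # True

# Scaling used for the tick labels
binWidthHz = Fs / fftLength        # frequency spacing between bins
frameDurationS = fftLength / Fs    # time step between non-overlapping frames
print(binWidthHz, frameDurationS)  # 39.0625 0.0256
```

With these values the 1000 Hz test tone falls between bins 25 and 26 (1000 / 39.0625 = 25.6). Note that rfft also returns the Nyquist bin (index fftLength/2), which the program below drops along with the mirrored half.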

Here's the code:

import matplotlib.pyplot as plt
import numpy as np

# Global Configuration
plotXAxisSecondsFlag = True                             # Set to True to Plot time in seconds, False to plot time in samples
plotYAxisHzFlag = True                                  # Set to True to Plot frequency in Hz, False to plot frequency in bins

Fs = 10000                                              # Sampling Frequency (Hz)
timePeriod = 10                                         # Time period in seconds
sampleLength = Fs*timePeriod

sinusoidFrequency = 1000                                # Frequency of sine wave (Hz)

fftLength = 256                                         # Length of the FFT
halfFftLength = fftLength >> 1

window = np.hanning(fftLength)

time = np.arange(sampleLength) / float(Fs)              # Generate sinusoid + harmonic with half the magnitude
x = np.sin(2*np.pi*sinusoidFrequency*time) * ((time[-1] - time)/1000) # Decreases in amplitude over time
x += 0.5 * np.sin(2*2*np.pi*sinusoidFrequency*time) * (time/1000)     # Increases in amplitude over time

# Add FFT frames to the spectrogram list - note, we use a Python list here because it is very easy to append to
spectrogramDataset = []

i = 0
while i <= (len(x) - fftLength):                        # Step through the whole dataset, one non-overlapping frame at a time
  x_discrete = x[i:i + fftLength]                       # Extract time domain frame
  x_discrete = x_discrete * window                      # Apply window function
  x_frequency = np.abs(np.fft.fft(x_discrete))          # Perform FFT
  x_frequency = x_frequency[:halfFftLength]             # Remove the redundant second half of the FFT result
  spectrogramDataset.append(x_frequency)                # Append frequency response to spectrogram dataset
  i = i + fftLength

# Plot the spectrogram
spectrogramDataset = np.asarray(spectrogramDataset)     # Convert to NumPy array, then rotate 90 degrees counter-clockwise so time runs along the x-axis and the highest frequency bin is in the top row
spectrogramDataset = np.rot90(spectrogramDataset)
z_min = np.min(spectrogramDataset)
z_max = np.max(spectrogramDataset)
plt.imshow(spectrogramDataset, cmap='gnuplot2', vmin = z_min, vmax = z_max, interpolation='nearest', aspect='auto')
freqbins, timebins = np.shape(spectrogramDataset)
xlocs = np.float32(np.linspace(0, timebins-1, 8))
if plotXAxisSecondsFlag:
  plt.xticks(xlocs, ["%.02f" % (i*fftLength/Fs) for i in xlocs])  # X axis is time (seconds): frame i starts at i*fftLength samples
  plt.xlabel('Time (s)')
else:
  plt.xticks(xlocs, ["%.0f" % (i*fftLength) for i in xlocs])      # X axis is samples
  plt.xlabel('Time (Samples)')

ylocs = np.round(np.linspace(0, freqbins-1, 11)).astype(int)   # Plain int avoids int16 overflow when multiplying by Fs
if plotYAxisHzFlag:
  plt.yticks(ylocs, ["%.02f" % (((halfFftLength-i-1)*Fs)/fftLength) for i in ylocs])  # Y axis is Hz (row 0 is the highest bin)
  plt.ylabel('Frequency (Hz)')
else:
  plt.yticks(ylocs, ["%d" % int(halfFftLength-i-1) for i in ylocs])                   # Y axis is Bins
  plt.ylabel('Frequency (Bins)')

plt.show()