Source code for mth5.timeseries.spectre.spectrogram

"""
Module contains a class that represents a spectrogram.
i.e. a 2D time series of Fourier coefficients, with one axis being time and the other frequency.
The datasets are xarray/dataframe and are fundamentally multivariate.

"""

from typing import List, Literal, Optional, Tuple, Union

# Third-party imports
import pandas as pd
import xarray as xr

# Logging (third-party)
from loguru import logger

# Local imports
from mt_metadata.common.band import Band
from mt_metadata.processing.aurora.frequency_bands import FrequencyBands

from mth5.timeseries.xarray_helpers import covariance_xr, initialize_xrda_2d


class Spectrogram(object):
    """
    Class to contain methods for STFT objects.

    TODO: Add OLS Z-estimates -- actually, these are properties of cross powers,
    not direct properties of spectrograms.
    TODO: Add Sims/Vozoff Z-estimates -- actually, these are properties of cross
    powers as well.

    **Note** Coherence is, similarly, a property of cross powers.  There are in
    fact very few features that we would derive from an unaveraged spectrogram.
    Pretty much everything except statistical moments comes from cross powers.

    Development Notes:
     - The spectrogram class is fundamental to MT processing, and normally
       appears during the STFT operation.
     - The extract_band method returns another Spectrogram, having the same time
       axis as the parent object, but only a slice of the frequency range.
       Both of these have in common that their frequency axes are uniformly
       spaced, delta-f, where delta-f is dictated by the time series sample
       rate and the FFT window length.
     - There is a sibling spectral-time-series container that should be
       considered.  Call it, for now, a FrequencyChunkedSpectrogram (or an
       AveragedSpectrogram).  This is a container similar to Spectrogram, but
       the frequencies are not uniformly spaced (instead, often logarithmically
       spaced); they are made from one or more (possibly multivariate)
       spectrograms, and a FrequencyBands object.  The key difference is that a
       FrequencyChunkedSpectrogram has a non-uniformly spaced frequency axis
       which was prescribed by a metadata object.  Most features, as well as
       TFs, have a FrequencyChunkedSpectrogram representation, where final TFs
       are just time-averaged FrequencyChunkedSpectrograms.

    TODO: consider factoring a simpler class that does not make the uniform
    frequency axis assumption.  Spectrogram would extend this class and add the
    _frequency_increment property (taken from the difference in the first two
    values of the frequency axis), and num_harmonics_in_band.
    """

    def __init__(self, dataset: Optional[xr.Dataset] = None):
        """
        Constructor.

        :param dataset: The spectrogram data.  The methods of this class index
            its "time" and "frequency" coordinates.  May be None, which yields
            a "Dataless Spectrogram".
        :type dataset: Optional[xr.Dataset]
        """
        self._dataset = dataset
        # Lazily populated caches for the properties of the same name.
        self._frequency_increment = None
        self._frequency_band = None

    def _lowest_frequency(self) -> float:
        """Return the minimum value of the frequency axis."""
        return self.dataset.frequency.min().item()

    def _highest_frequency(self) -> float:
        """Return the maximum value of the frequency axis."""
        return self.dataset.frequency.max().item()

    def __str__(self) -> str:
        """Returns a description of frequency, time and channel coverage."""
        if self.dataset is None:
            return "Dataless Spectrogram"
        intro = "Spectrogram:"
        frequency_coverage = (
            f"{self.dataset.sizes['frequency']} harmonics, {self.frequency_increment}Hz spaced \n"
            f" from {self.dataset.frequency.data[0]} to {self.dataset.frequency.data[-1]} Hz."
        )
        time_coverage = f"\n{self.dataset.sizes['time']} Time observations"
        time_coverage = f"{time_coverage} \nStart: {self.dataset.time.data[0]}"
        time_coverage = f"{time_coverage} \nEnd: {self.dataset.time.data[-1]}"
        # Data-variable names are Hashable, not necessarily str; str() keeps
        # the join from raising on non-string names.
        channel_coverage = "\n".join(str(ch) for ch in self.dataset.data_vars.keys())
        channel_coverage = f"\nChannels present: \n{channel_coverage}"
        return (
            intro
            + "\n"
            + frequency_coverage
            + "\n"
            + time_coverage
            + "\n"
            + channel_coverage
        )

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def dataset(self):
        """returns the underlying xarray data"""
        return self._dataset

    @property
    def dataarray(self):
        """returns the underlying dataset converted with ``Dataset.to_array()``"""
        return self._dataset.to_array()

    @property
    def time_axis(self):
        """returns the time axis of the underlying xarray"""
        return self.dataset.time

    @property
    def frequency_axis(self):
        """returns the frequency axis of the underlying xarray"""
        return self.dataset.frequency

    @property
    def frequency_band(self) -> Band:
        """
        returns a frequency band object representing the spectrogram's band
        (assumes continuous).

        The band spans the min to max of the frequency axis and is cached after
        the first access.
        """
        if self._frequency_band is None:
            band = Band(
                frequency_min=self.frequency_axis.min().item(),
                frequency_max=self.frequency_axis.max().item(),
            )
            self._frequency_band = band
        return self._frequency_band

    @property
    def frequency_increment(self):
        """
        returns the "delta f" of the frequency axis
         - assumes uniformly sampled in frequency domain

        Computed as the difference between the first two frequencies and
        cached.  If the frequency axis has fewer than two entries, the string
        "undefined" is returned instead of a number.
        """
        if self._frequency_increment is None:
            frequency_axis = self.dataset.frequency
            try:
                self._frequency_increment = (
                    frequency_axis.data[1] - frequency_axis.data[0]
                )
            except IndexError:
                msg = "frequency increment for spectrogram with frequency axis of length 1 is not defined"
                logger.debug(msg)
                self._frequency_increment = "undefined"
        return self._frequency_increment

    def num_harmonics_in_band(self, frequency_band: Band, epsilon: float = 1e-7) -> int:
        """
        Returns the number of harmonics within the frequency band in the
        underlying dataset.

        Parameters
        ----------
        frequency_band: Band
            The band to count harmonics in.
        epsilon: float
            Tolerance added on each side of the band (passed to extract_band).

        Returns
        -------
        num_harmonics: int
            The number of harmonics in the underlying dataset within the given
            frequency band.
        """
        extracted_spectrogram = self.extract_band(frequency_band, epsilon=epsilon)
        num_harmonics = len(extracted_spectrogram.frequency_axis)
        return num_harmonics

    def extract_band(
        self,
        frequency_band: Band,
        channels: Optional[list] = None,
        epsilon: Optional[float] = None,
    ):
        """
        Returns another instance of Spectrogram, with the frequency axis reduced
        to the input band.

        Parameters
        ----------
        frequency_band: Band
            Harmonics within [lower_bound - epsilon, upper_bound + epsilon]
            are kept.
        channels: Optional[list]
            Channel names to keep.  If None or empty, all channels are kept.
        epsilon: Optional[float]
            Tolerance on each side of the band.  If None, half the frequency
            increment is used (the legacy default).  When the increment is
            undefined (fewer than two harmonics), 1e-7 is used instead.

        Returns
        -------
        spectrogram: Spectrogram
            A Spectrogram object with only the extracted band of the dataset.
        """
        if epsilon is None:
            frequency_increment = self.frequency_increment
            if isinstance(frequency_increment, str):
                # frequency_increment is the string "undefined" for a
                # single-harmonic axis; dividing it would raise TypeError.
                # Fall back to the default of the module-level extract_band.
                epsilon = 1e-7
            else:
                # Half a bin on each side: the legacy default.
                epsilon = frequency_increment / 2.0
        extracted_band_dataset = extract_band(
            frequency_band, self.dataset, channels=channels, epsilon=epsilon
        )
        return Spectrogram(dataset=extracted_band_dataset)

    def cross_power_label(self, ch1: str, ch2: str, join_char: str = "_"):
        """joins channel names with join_char"""
        return f"{ch1}{join_char}{ch2}"

    def _validate_frequency_bands(
        self,
        frequency_bands: FrequencyBands,
        strict: bool = True,
    ):
        """
        Make sure that the frequency bands passed are relevant.
        If not, drop and warn.

        :param frequency_bands: A collection of bands
        :type frequency_bands: FrequencyBands
        :param strict: If true, a band must be contained in this spectrogram's
            frequency_band to be valid; if false, any overlapping band is valid.
        :type strict: bool
        :return: A new FrequencyBands holding only the valid bands.
        """
        candidate_bands = list(frequency_bands.bands())
        if strict:
            valid_bands = [
                x for x in candidate_bands if self.frequency_band.contains(x)
            ]
        else:
            valid_bands = [
                x for x in candidate_bands if self.frequency_band.overlaps(x)
            ]

        num_dropped = len(candidate_bands) - len(valid_bands)
        if num_dropped:
            relation = "contained in" if strict else "overlapping"
            logger.warning(
                f"Dropping {num_dropped} of {len(candidate_bands)} frequency bands "
                f"not {relation} the spectrogram band {self.frequency_band}"
            )

        lower_bounds = [x.lower_bound for x in valid_bands]
        upper_bounds = [x.upper_bound for x in valid_bands]
        valid_frequency_bands = FrequencyBands(
            pd.DataFrame(
                data={
                    "lower_bound": lower_bounds,
                    "upper_bound": upper_bounds,
                }
            )
        )
        return valid_frequency_bands

    def cross_powers(
        self,
        frequency_bands: FrequencyBands,
        channel_pairs: Optional[List[Tuple[str, str]]] = None,
    ):
        """
        Compute cross powers between channel pairs for given frequency bands.

        TODO: Add handling for case when band in frequency_bands is not
        contained in self.frequencies.  Currently such bands are dropped by
        _validate_frequency_bands and their entries are never written.

        Parameters
        ----------
        frequency_bands : FrequencyBands
            The frequency bands to compute cross powers for.  Each element of
            this iterable tells the lower and upper bounds of the cross-power
            calculation bands.  These may become objects with information
            about tapers as well.
        channel_pairs : list of tuples, optional
            List of channel pairs to compute cross powers for.
            If None, all pairs (including auto-powers) will be used.

        Returns
        -------
        The array built by initialize_xrda_2d.  It is indexed by a "variable"
        dimension (one label per channel pair, e.g. 'ex_hy'), a "frequency"
        dimension (the band centers of ALL requested bands) and a "time"
        dimension.  Each entry is the band-mean of ch1 * conj(ch2).
        """
        from itertools import combinations_with_replacement

        valid_frequency_bands = self._validate_frequency_bands(frequency_bands)

        # If no channel pairs are specified, use all pairs, including auto-powers.
        if channel_pairs is None:
            channels = list(self.dataset.data_vars.keys())
            channel_pairs = list(combinations_with_replacement(channels, 2))

        var_names = [self.cross_power_label(ch1, ch2) for ch1, ch2 in channel_pairs]

        # NOTE(review): the frequency coordinate comes from the full
        # frequency_bands, while writes below use band.center_frequency from
        # the validated copy.  The two must compare equal for .loc to match.
        xpower_array = initialize_xrda_2d(
            var_names,
            coords={
                "frequency": frequency_bands.band_centers(),
                "time": self.dataset.time.values,
            },
            dtype=complex,
        )

        for band in valid_frequency_bands.bands():
            band_data = self.extract_band(band).dataset
            for ch1, ch2 in channel_pairs:
                label = self.cross_power_label(ch1, ch2)
                # Always compute as ch1 * conj(ch2), averaged over the band's harmonics.
                xpower = (band_data[ch1] * band_data[ch2].conj()).mean(dim="frequency")
                xpower_array.loc[
                    dict(
                        frequency=band.center_frequency,
                        variable=label,
                        time=slice(None),
                    )
                ] = xpower
        return xpower_array

    def covariance_matrix(
        self, band_data: Optional["Spectrogram"] = None, method: str = "numpy_cov"
    ) -> xr.DataArray:
        """
        TODO: Add tests for this WIP

        Work-in-progress method.
        Compute the full covariance matrix for spectrogram data.

        For complex-valued data, the result is a Hermitian matrix where:
         - diagonal elements are real-valued variances
         - off-diagonal element [i,j] is E[ch_i * conj(ch_j)]
         - off-diagonal element [j,i] is the complex conjugate of [i,j]

        Parameters
        ----------
        band_data : Spectrogram, optional
            If provided, compute the covariance for this data.
            If None, use the full spectrogram.
        method : str
            Computation method.  Currently only supports 'numpy_cov'.

        Returns
        -------
        xr.DataArray
            Hermitian covariance matrix with proper channel labeling.
            For channels i,j: matrix[i,j] = E[ch_i * conj(ch_j)]

        Raises
        ------
        ValueError
            If method is not 'numpy_cov'.
        """
        data = band_data if band_data is not None else self
        flat_data = data.flatten(chunk_by="time")

        if method == "numpy_cov":
            # covariance_xr takes a DataArray with the channels stacked along "variable".
            stacked = flat_data.to_array(dim="variable")
            return covariance_xr(stacked)
        else:
            raise ValueError(f"Unknown method: {method}")

    def _get_all_channel_pairs(self) -> List[Tuple[str, str]]:
        """Get all unique channel pairs (upper triangle, excluding the diagonal)."""
        from itertools import combinations

        channels = list(self.dataset.data_vars.keys())
        return list(combinations(channels, 2))

    def flatten(self, chunk_by: Literal["time", "frequency"] = "time") -> xr.Dataset:
        """
        Reshape the 2D spectrogram into a 1D flattened xarray (time-chunked by
        default).

        Parameters
        ----------
        chunk_by: Literal["time", "frequency"]
            Reshaping the 2D spectrogram can be done two ways, basically
            "row-major" or "column-major".  Either frequency is held constant
            while iterating over time, or time is held constant while iterating
            over frequency (in the inner loop).

        Returns
        -------
        xarray.Dataset : The dataset from the band spectrogram, stacked along a
            new "observation" dimension.

        Raises
        ------
        ValueError
            If chunk_by is not "time" or "frequency".

        Development Notes:
        The flattening used in TF calculation by default is opposite to here:
        dataset.stack(observation=("frequency", "time"))
        However, for feature extraction, it may make sense to swap the order:
        xrds = band_spectrogram.dataset.stack(observation=("time", "frequency"))
        This is like chunking into time windows and allows individual features
        to be computed on each time window, if desired.
        The time series still needs to be split, though.  Splitting by time
        would be a reshape by (last_freq_index - first_freq_index).  Using pure
        xarray this may not matter, but if we drop down into numpy it could be
        useful.
        """
        if chunk_by == "time":
            observation = ("time", "frequency")
        elif chunk_by == "frequency":
            observation = ("frequency", "time")
        else:
            msg = f"Invalid argument chunk_by={chunk_by}, must be one of ['time', 'frequency']"
            logger.error(msg)
            raise ValueError(msg)
        return self.dataset.stack(observation=observation)
def extract_band(
    frequency_band: Band,
    fft_obj: Union[xr.Dataset, xr.DataArray],
    channels: Optional[list] = None,
    epsilon: float = 1e-7,
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Extract a frequency band from an xarray object representing a spectrogram.

    Entries whose frequency lies in
    [frequency_band.lower_bound - epsilon, frequency_band.upper_bound + epsilon]
    are kept; everything else is dropped along the frequency axis.

    TODO: Update variable names.

    Development Notes:
    The base object should ideally be an xr.DataArray rather than an
    xr.Dataset.  ``where(..., drop=True)`` on an h5py-backed Dataset has been
    seen to fail with a TypeError:
        File "stringsource", line 2, in h5py.h5r.Reference.__reduce_cython__
        TypeError: no default __reduce__ due to non-trivial __cinit__
    The same operation works on a DataArray, so that is used as a fallback.

    Parameters
    ----------
    frequency_band: mt_metadata.common.band.Band
        Specifies the interval corresponding to a frequency band.
    fft_obj: xarray.core.dataset.Dataset
        Short-time-Fourier-transformed data.  Can be multichannel.
    channels: list
        Channel names to extract.  If None or empty, all channels are kept.
    epsilon: float
        Use this when you are worried about missing a frequency due to
        round-off error.  This is in general not needed if a df/2 pad is used
        around the true harmonics.

    Returns
    -------
    extracted_band: xr.Dataset or xr.DataArray
        The frequencies within the band passed into this function.  A warning
        is logged (and the empty result still returned) if no frequency falls
        within the band.
    """
    band_floor = frequency_band.lower_bound - epsilon
    band_ceiling = frequency_band.upper_bound + epsilon
    within_band = (fft_obj.frequency >= band_floor) & (fft_obj.frequency <= band_ceiling)

    try:
        extracted_band = fft_obj.where(within_band, drop=True)
    except TypeError:
        # Fallback described in Development Notes: round-trip via a DataArray.
        as_array = fft_obj.to_array().where(within_band, drop=True)
        extracted_band = as_array.to_dataset("variable")

    if channels:
        extracted_band = extracted_band[channels]

    if not len(extracted_band.frequency):
        logger.warning(
            f"Frequency band {frequency_band} does not overlap with the frequencies "
            f"of the input dataset. Frequencies in dataset are: {fft_obj.frequency.values}. "
            "Skipping band extraction. Consider reforming the bands."
        )
    return extracted_band