# Source code for mth5.timeseries.spectre.helpers

"""
This is a placeholder module for functions that are used in testing and development of spectrograms.
"""

import pathlib
from typing import List, Literal, Optional, Union

import xarray as xr
from loguru import logger
from mt_metadata.processing.fourier_coefficients import Decimation as FCDecimation
from mt_metadata.processing.fourier_coefficients.decimation import (
    fc_decimations_creator,
    get_degenerate_fc_decimation,
)

import mth5
from mth5.mth5 import MTH5
from mth5.processing.spectre.stft import run_ts_to_stft_scipy
from mth5.utils.helpers import path_or_mth5_object


# Channel-summary columns used to group channels before FC processing; the
# loops below unpack each group key as (survey, station, sample_rate).
# NOTE: this list is shared as a default argument -- never mutate it in place.
GROUPBY_COLUMNS: List[str] = ["survey", "station", "sample_rate"]
@path_or_mth5_object
def add_fcs_to_mth5(
    m: MTH5,
    fc_decimations: Optional[Union[str, list]] = None,
    groupby_columns: List[str] = GROUPBY_COLUMNS,
) -> None:
    """
    Add Fourier Coefficient levels to an existing MTH5.

    TODO: This method currently loops the hierarchy of the h5, and then calls an
    operator.  How about making a single table that represents the loop up front
    and then looping once that table instead of this nested loop business?  We
    would need a function that takes as input the groupby_columns.

    **Notes:**

    - This module computes the FCs differently than the legacy aurora pipeline.
      It uses scipy.signal.spectrogram.  There is a test in Aurora to confirm
      that they are equivalent if we are not using fancy pre-whitening.

    Parameters
    ----------
    m: MTH5 object
        The mth5 file, open in append mode.
    fc_decimations: Optional[Union[str, list]]
        This specifies the scheme to use for decimating the time series when
        building the FC layer.

        None: Just use default (something like four decimation levels,
        decimated by 4 each time say).  The defaults are built separately
        for every sample rate encountered in the channel summary.

        String: Controlled Vocabulary, values are a work in progress, that will
        allow custom definition of the fc_decimations for some common cases.
        For example, say you have stored already decimated time series, then
        you want simply the zeroth decimation for each run, because the
        decimated time series live under another run container, and that will
        get its own FCs.  This is experimental.  Currently "degenerate" is the
        only recognised value.

        List: (**UNTESTED**) -- This means that the user thought about the
        decimations that they want to create and is passing them explicitly.
        -- probably will need to be a dictionary actually, since this would get
        redefined at each sample rate.  Note that each element's
        ``time_period`` is overwritten with the time period of each run.
    groupby_columns: List[str]
        Channel-summary columns used to group channels.  Each group key is
        unpacked as (survey, station, sample_rate), so those three columns are
        expected, in that order.

    Raises
    ------
    ValueError
        If the position of a decimation in the scheme does not match its
        ``time_series_decimation.level``, or if ``fc_decimations`` is an
        unrecognised string.
    """
    # Group the channel summary by survey, station, sample_rate
    channel_summary_df = m.channel_summary.to_dataframe()
    grouper = channel_summary_df.groupby(groupby_columns)
    logger.debug(f"Detected {len(grouper)} unique station-sample_rate instances")

    # TODO: is there a way to use the groupby_columns var instead of this tuple?
    for (survey, station, sample_rate), _group in grouper:
        msg = f"\n\n\nsurvey: {survey}, station: {station}, sample_rate {sample_rate}"
        logger.info(msg)
        station_obj = m.get_station(station, survey)
        run_summary = station_obj.run_summary

        # Resolve the FC decimation scheme for THIS sample rate.  Bind it to a
        # new name: rebinding `fc_decimations` (as was done previously) made
        # every later group reuse the defaults built for the first group's
        # sample rate.
        group_fc_decimations = _fc_decimations_from_sample_rate(
            fc_decimations=fc_decimations, sample_rate=sample_rate
        )

        # TODO: Make this a function that can be done using df.apply()
        for i_run_row, run_row in run_summary.iterrows():
            logger.info(
                f"survey: {survey}, station: {station}, sample_rate {sample_rate}, i_run_row {i_run_row}"
            )

            # Access Run
            run_obj = m.from_reference(run_row.hdf5_reference)
            run_time_period = run_obj.metadata.time_period

            # Set the time period on every decimation level.
            # TODO: Should this be over-writing time period if it is already there?
            for fc_decimation in group_fc_decimations:
                fc_decimation.time_period = run_time_period

            # Access the data to Fourier transform.  Read the bounds from the
            # run itself rather than from a leftover loop variable, which would
            # be unbound when the decimation scheme is empty.
            runts = run_obj.to_runts(
                start=run_time_period.start,
                end=run_time_period.end,
            )
            run_xrds = runts.dataset

            # access container for FCs
            fc_group = station_obj.fourier_coefficients_group.add_fc_group(
                run_obj.metadata.id
            )

            # If timing corrections were needed they could go here, right before STFT
            # TODO: replace i_dec_level with ts_decimation.level in the following
            for i_dec_level, fc_decimation in enumerate(group_fc_decimations):
                ts_decimation = fc_decimation.time_series_decimation

                # Explicit check (an `assert` would vanish under `python -O`).
                if i_dec_level != ts_decimation.level:
                    msg = (
                        "decimation level has unexpected value: expected "
                        f"{i_dec_level}, got {ts_decimation.level}"
                    )
                    raise ValueError(msg)

                if ts_decimation.level != 0:
                    # Decimation is cascaded: each level decimates the output
                    # of the previous one.  The accessor returns a new dataset
                    # (the time axis cannot be shortened in place), so the
                    # result must be captured.
                    # NOTE(review): assumes the returned dataset's
                    # `sample_rate` attr reflects the new rate -- confirm.
                    target_sample_rate = run_xrds.sample_rate / ts_decimation.factor
                    run_xrds = run_xrds.sps_filters.decimate(
                        target_sample_rate=target_sample_rate
                    )

                _add_spectrogram_to_mth5(
                    fc_decimation=fc_decimation,
                    run_obj=run_obj,
                    run_xrds=run_xrds,
                    fc_group=fc_group,
                )
    return
def _fc_decimations_from_sample_rate(
    sample_rate: float,
    fc_decimations: Optional[Union[str, list]] = None,
) -> list:
    """
    Resolve the FC decimation scheme to use for one sample rate.

    Really only seems to be used by add_fcs_to_mth5.

    Development Notes:
    This function is probably overslicing the add_fcs_to_mth5 function

    Parameters
    ----------
    sample_rate: float
        Sample rate of the time series; used when building the default or
        degenerate schemes.
    fc_decimations: Optional[Union[str, list]]
        None (or otherwise falsy, e.g. an empty list): build defaults with
        ``fc_decimations_creator``.
        "degenerate": build with ``get_degenerate_fc_decimation``.
        Any other non-string value is returned unchanged.

    Returns
    -------
    fc_decimations:
        An iterable of FC decimation objects, one per decimation level.

    Raises
    ------
    ValueError
        If ``fc_decimations`` is a string other than "degenerate".  Previously
        the string was returned as-is and the caller failed later while
        iterating over its characters.
    """
    # Get the FC decimation schemes if not provided -- note that this depends only on sample rate
    if not fc_decimations:
        msg = "FC Decimations not supplied, creating defaults on the fly"
        logger.info(f"{msg}")
        return fc_decimations_creator(initial_sample_rate=sample_rate, time_period=None)

    if isinstance(fc_decimations, str):
        if fc_decimations == "degenerate":
            return get_degenerate_fc_decimation(sample_rate)
        msg = (
            f"Unrecognised fc_decimations value {fc_decimations!r}; "
            "expected None, 'degenerate', or a list of FC decimations"
        )
        raise ValueError(msg)

    return fc_decimations


def _add_spectrogram_to_mth5(
    fc_decimation: FCDecimation,
    run_obj: mth5.groups.RunGroup,
    run_xrds: xr.Dataset,
    fc_group: mth5.groups.FCGroup,
) -> None:
    """
    Compute, calibrate and store the spectrogram for one decimation level.

    This function has been factored out of add_fcs_to_mth5.
    This is the most atomic level of adding FCs and may be useful as standalone method.

    If the time series is too short for this decimation level (per
    ``fc_decimation.is_valid_for_time_series_length``) nothing is written.

    Parameters
    ----------
    fc_decimation : FCDecimation
        Metadata about how the decimation level is to be processed
    run_obj : mth5.groups.RunGroup
        The run the time series came from; its channel filters are used to
        calibrate the Fourier coefficients.
    run_xrds : xarray.core.dataset.Dataset
        Time series to be converted to a spectrogram and stored in MTH5.
    fc_group : mth5.groups.FCGroup
        Container under which a new decimation-level group is created.

    Returns
    -------
    None
        The result is written into ``fc_group``; nothing is returned.
    """
    # check if this decimation level yields a valid spectrogram
    n_samples = run_xrds.time.shape[0]
    if not fc_decimation.is_valid_for_time_series_length(n_samples):
        logger.info(
            f"Decimation Level {fc_decimation.time_series_decimation.level} invalid, TS of {n_samples} samples too short"
        )
        return

    spectrogram = run_ts_to_stft_scipy(fc_decimation, run_xrds)
    stft_obj = calibrate_stft_obj(spectrogram.dataset, run_obj)

    # Pack FCs into h5 and update metadata
    fc_decimation_group = fc_group.add_decimation_level(
        f"{fc_decimation.time_series_decimation.level}",
        decimation_level_metadata=fc_decimation,
    )
    fc_decimation_group.from_xarray(
        stft_obj, fc_decimation_group.metadata.decimation.sample_rate
    )
    fc_decimation_group.update_metadata()
    fc_group.update_metadata()


@path_or_mth5_object
def read_back_fcs(
    m: Union[MTH5, pathlib.Path, str],
    mode: str = "r",
    groupby_columns: List[str] = GROUPBY_COLUMNS,
) -> None:
    """
    Loops over stations in the channel summary of input (m) grouping by common sample_rate.

    Then loop over the runs in the corresponding FC Group.  Finally, within an
    fc_group, loop decimation levels and read data to xarray.  Log info about
    the shape of the xarray.

    This is a helper function for tests.  It was used as a sanity check while
    debugging the FC files, and also is a good example for how to access the
    data at each level for each channel.

    Development Notes:
    The Time axis of the FC array changes from decimation_level to
    decimation_level.  The frequency axis shape will depend on the window
    length that was used to perform STFT.  This is currently storing all
    (positive frequency) fcs by default, but future versions can also have
    selected bands within an FC container.

    Only the "hx" and "hy" channels are read back at each decimation level.

    Parameters
    ----------
    m: Union[MTH5, pathlib.Path, str]
        Either a path to an mth5, or an MTH5 object that the FCs will be read
        back from.
    mode: str
        The mode to open the MTH5 file in.  Defaults to (r)ead only.  Not
        read in this function body; presumably consumed by the
        ``path_or_mth5_object`` decorator -- verify.
    groupby_columns: List[str]
        Channel-summary columns used to group channels; each group key is
        unpacked as (survey, station, sample_rate).
    """
    channel_summary_df = m.channel_summary.to_dataframe()
    logger.debug(channel_summary_df)
    grouper = channel_summary_df.groupby(groupby_columns)
    for (survey, station, sample_rate), _group in grouper:
        logger.info(f"survey: {survey}, station: {station}, sample_rate {sample_rate}")
        station_obj = m.get_station(station, survey)
        fc_groups = station_obj.fourier_coefficients_group.groups_list
        logger.info(f"FC Groups: {fc_groups}")
        for run_id in fc_groups:
            fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_id)
            for dec_level_id in fc_group.groups_list:
                dec_level = fc_group.get_decimation_level(dec_level_id)
                xrds = dec_level.to_xarray(["hx", "hy"])
                msg = f"dec_level {dec_level_id}"
                msg = f"{msg} \n Time axis shape {xrds.time.data.shape}"
                msg = f"{msg} \n Freq axis shape {xrds.frequency.data.shape}"
                logger.debug(msg)

    return
def calibrate_stft_obj(
    stft_obj: xr.Dataset,
    run_obj: mth5.groups.RunGroup,
    units: Literal["MT", "SI"] = "MT",
    channel_scale_factors: Optional[dict] = None,
) -> xr.Dataset:
    """
    Calibrates frequency domain data into MT units.

    Each data variable of ``stft_obj`` is divided (in place) by the complex
    response of those of its channel's filters that are marked as applied.

    Development Notes:
    The calibration often raises a runtime warning due to DC term in
    calibration response = 0.
    TODO: It would be nice to suppress this, maybe by only calibrating the
    non-dc terms and directly assigning np.nan to the dc component when
    DC-response is zero.

    Parameters
    ----------
    stft_obj : xarray.core.dataset.Dataset
        Time series of Fourier coefficients to be calibrated.  Modified in place.
    run_obj : mth5.groups.master_station_run_channel.RunGroup
        Provides information about filters for calibration
    units : string
        usually "MT", contemplating supporting "SI".  "SI" currently only
        triggers a warning; no unit conversion is performed.
    channel_scale_factors : dict or None
        keyed by channel, supports a single scalar to apply to that channels
        data (channels not present use 1.0).  Useful for debugging.  Should not
        be used in production; a warning is logged when it is supplied.

    Returns
    -------
    stft_obj : xarray.core.dataset.Dataset
        Time series of calibrated Fourier coefficients (the same object that
        was passed in).
    """
    # Emit run-level warnings once, not once per channel.
    if units == "SI":
        logger.warning("Warning: SI Units are not robustly supported issue #36")
    if channel_scale_factors:
        logger.warning(
            "channel_scale_factors supplied -- intended for debugging, not production"
        )

    for channel_id in stft_obj.keys():
        channel = run_obj.get_channel(channel_id)
        # The channel whose filters we use.  It supplies BOTH the response and
        # the per-filter `applied` flags so the two stay index-aligned, even
        # when hy falls back to hx below (previously hx's indices were used to
        # index hy's, possibly empty, metadata.filters).
        filter_channel = channel
        channel_response = channel.channel_response
        if not channel_response.filters_list:
            msg = f"Channel {channel_id} with empty filters list detected"
            logger.warning(msg)
            if channel_id == "hy":
                msg = "Channel hy has no filters, try using filters from hx"
                logger.warning(msg)
                filter_channel = run_obj.get_channel("hx")
                channel_response = filter_channel.channel_response

        # NOTE(review): assumes metadata.filters is index-aligned with
        # channel_response.filters_list -- confirm against mt_metadata.
        candidate_indices = channel_response.get_indices_of_filters_to_remove(
            include_decimation=False, include_delay=False
        )
        indices_to_flip = [
            i for i in candidate_indices if filter_channel.metadata.filters[i].applied
        ]
        filters_to_remove = [channel_response.filters_list[i] for i in indices_to_flip]
        if not filters_to_remove:
            logger.warning("No filters to remove")

        calibration_response = channel_response.complex_response(
            stft_obj.frequency.data, filters_list=filters_to_remove
        )

        if channel_scale_factors:
            calibration_response /= channel_scale_factors.get(channel_id, 1.0)

        # TODO: FIXME Sometimes raises a runtime warning due to DC term in calibration response = 0
        stft_obj[channel_id].data /= calibration_response

    return stft_obj