From 2fa5aca316f650cf420cd943b3c24f83eba03b4e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 28 May 2024 14:50:46 +0200 Subject: [PATCH 01/38] Add filter storage --- .../spectral/epochs_multivariate.py | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index a8332c1b..f5a173ea 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -105,14 +105,18 @@ class _EpochMeanMultivariateConEstBase(_AbstractConEstBase): n_steps = None patterns = None + filters = None con_scores_dtype = np.float64 - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + def __init__( + self, n_signals, n_cons, n_freqs, n_times, n_jobs=1, store_filters=False + ): self.n_signals = n_signals self.n_cons = n_cons self.n_freqs = n_freqs self.n_times = n_times self.n_jobs = n_jobs + self.store_filters = store_filters # include time dimension, even when unused for indexing flexibility if n_times == 0: @@ -190,9 +194,11 @@ class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): name: Optional[str] = None accumulate_psd = False - def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1): + def __init__( + self, n_signals, n_cons, n_freqs, n_times, n_jobs=1, store_filters=False + ): super(_MultivariateCohEstBase, self).__init__( - n_signals, n_cons, n_freqs, n_times, n_jobs + n_signals, n_cons, n_freqs, n_times, n_jobs, store_filters ) def compute_con(self, indices, ranks, n_epochs=1): @@ -211,6 +217,10 @@ def compute_con(self, indices, ranks, n_epochs=1): self.patterns = np.full( (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan ) + if self.store_filters: + self.filters = np.full( + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan + ) con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( @@ -436,20 +446,18 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)] beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] + # Part of Eqs. 46 & 47; i.e. transform filters to channel space + alpha_Ubar = np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3)) + beta_Ubar = np.matmul(U_bar_bb, np.expand_dims(beta, axis=3)) + # Eq. 46 (seed spatial patterns) self.patterns[0, con_i, :n_seeds] = ( - np.matmul( - np.real(C[..., :n_seeds, :n_seeds]), - np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3)), - ) + np.matmul(np.real(C[..., :n_seeds, :n_seeds]), alpha_Ubar) )[..., 0].T # Eq. 47 (target spatial patterns) self.patterns[1, con_i, :n_targets] = ( - np.matmul( - np.real(C[..., n_seeds:, n_seeds:]), - np.matmul(U_bar_bb, np.expand_dims(beta, axis=3)), - ) + np.matmul(np.real(C[..., n_seeds:, n_seeds:]), beta_Ubar) )[..., 0].T # Eq. 7 @@ -461,6 +469,10 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): * np.linalg.norm(beta, axis=2) ).T + if self.store_filters: + self.filters[0, con_i, :n_seeds] = alpha_Ubar + self.filters[1, con_i, :n_targets] = beta_Ubar + def _compute_mim(self, E, seed_idcs, target_idcs, con_i): """Compute MIM (a.k.a. GIM if seeds == targets) for one connection.""" # Eq. 14 @@ -641,16 +653,24 @@ def _compute_patterns( alpha = np.matmul(T_aa, np.expand_dims(a, axis=3)) # filter for seeds beta = np.matmul(T_bb, np.expand_dims(b, axis=3)) # filter for targets - # Eq. 
14; U_bar inclusion follows Eqs. 46 & 47 of Ewald et al. (2012) + # Eqs. 46 & 47 of Ewald et al. (2012); i.e. transform filters to channel space + alpha_Ubar = np.matmul(U_bar_aa, alpha) + beta_Ubar = np.matmul(U_bar_bb, beta) + + # Eq. 14 # seed spatial patterns self.patterns[0, con_i, :n_seeds] = ( - np.matmul(np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, alpha)) + np.matmul(np.real(C[..., :n_seeds, :n_seeds]), alpha_Ubar) )[..., 0].T # target spatial patterns self.patterns[1, con_i, :n_targets] = ( - np.matmul(np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, beta)) + np.matmul(np.real(C[..., n_seeds:, n_seeds:]), beta_Ubar) )[..., 0].T + if self.store_filters: + self.filters[0, con_i, :n_seeds] = alpha_Ubar + self.filters[1, con_i, :n_targets] = beta_Ubar + class _GCEstBase(_EpochMeanMultivariateConEstBase): """Base multivariate state-space Granger causality estimator.""" @@ -844,9 +864,7 @@ def _whittle_lwr_recursion(self, G): ) # forward autocov G_b = np.reshape( np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn), order="F" - ).transpose( - 0, 2, 1 - ) # backward autocov + ).transpose(0, 2, 1) # backward autocov A_f = np.zeros((t, n, qn)) # forward coefficients A_b = np.zeros((t, n, qn)) # backward coefficients From 68a0890054a72338c0ef89226b374c846911c5b7 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 29 May 2024 13:43:27 +0200 Subject: [PATCH 02/38] Refactor results reshaping --- .../spectral/epochs_multivariate.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index f5a173ea..d7eda288 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -179,6 +179,15 @@ def reshape_csd(self): self._acc, (self.n_signals, self.n_signals, self.n_freqs, self.n_times) ).transpose(3, 2, 0, 1) + def reshape_results(self): + """Remove time dimension from results, if necessary.""" + if self.n_times == 0: + self.con_scores = self.con_scores[..., 0] + if self.patterns is not None: + self.patterns = self.patterns[..., 0] + if self.filters is not None: + self.filters = self.filters[..., 0] + class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): """Base estimator for multivariate coherency methods. @@ -323,13 +332,6 @@ def _compute_t(self, C_r, n_seeds): return np.real(T) # make T real if check passes - def reshape_results(self): - """Remove time dimension from results, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[..., 0] - if self.patterns is not None: - self.patterns = self.patterns[..., 0] - def _invsqrtm(C, T, n_seeds): """Compute inverse sqrt of CSD over times (used for CaCoh, MIC, & MIM). @@ -1027,11 +1029,6 @@ def _partial_covar(self, V, seeds, targets): return V[np.ix_(times, seeds, seeds)] - W - def reshape_results(self): - """Remove time dimension from con. scores, if necessary.""" - if self.n_times == 0: - self.con_scores = self.con_scores[:, :, 0] - def _gc_compute_H(A, C, K, z_k, I_n, I_m): """Compute transfer function for innovations-form state-space params. 
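
[Illustrative sketch, not part of the patch series] The two patches above add a ``filters`` attribute that is filled during ``compute_con()`` and a shared ``reshape_results()`` helper. Assuming toy shapes, with ``alpha_Ubar``/``beta_Ubar`` standing in for the channel-space filters computed in ``_compute_mic()``, the storage convention they introduce looks like this in plain NumPy:

    import numpy as np

    # toy sizes: 1 connection, up to 3 channels per group, 4 freqs, 1 time point
    n_cons, max_n_chans, n_freqs, n_times = 1, 3, 4, 1
    n_seeds, n_targets = 3, 2

    # axis 0 separates seed (0) and target (1) filters; unused rows stay NaN
    filters = np.full((2, n_cons, max_n_chans, n_freqs, n_times), np.nan)

    # stand-ins for the filters, shaped (times, freqs, channels, 1) as in the diff
    alpha_Ubar = np.random.randn(n_times, n_freqs, n_seeds, 1)
    beta_Ubar = np.random.randn(n_times, n_freqs, n_targets, 1)

    # indexing as corrected by patch 03 below: drop the trailing singleton axis
    # and transpose to (channels, freqs, times) before storing
    filters[0, 0, :n_seeds] = alpha_Ubar[..., 0].T
    filters[1, 0, :n_targets] = beta_Ubar[..., 0].T

    # reshape_results() then drops the time axis when n_times == 0
    filters = filters[..., 0]

Keeping the filters in channel space is what lets them be reapplied to unseen data, which the decoding classes added later in this series rely on.
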
From 0368367d17cbd46a42dd2530d063c1009ad28784 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 29 May 2024 13:43:52 +0200 Subject: [PATCH 03/38] Fix filter indexing for storage --- mne_connectivity/spectral/epochs_multivariate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index d7eda288..74545d95 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -472,8 +472,8 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): ).T if self.store_filters: - self.filters[0, con_i, :n_seeds] = alpha_Ubar - self.filters[1, con_i, :n_targets] = beta_Ubar + self.filters[0, con_i, :n_seeds] = alpha_Ubar[..., 0].T + self.filters[1, con_i, :n_targets] = beta_Ubar[..., 0].T def _compute_mim(self, E, seed_idcs, target_idcs, con_i): """Compute MIM (a.k.a. GIM if seeds == targets) for one connection.""" @@ -670,8 +670,8 @@ def _compute_patterns( )[..., 0].T if self.store_filters: - self.filters[0, con_i, :n_seeds] = alpha_Ubar - self.filters[1, con_i, :n_targets] = beta_Ubar + self.filters[0, con_i, :n_seeds] = alpha_Ubar[..., 0].T + self.filters[1, con_i, :n_targets] = beta_Ubar[..., 0].T class _GCEstBase(_EpochMeanMultivariateConEstBase): From bb75cb10244e13c08529c89c03daf2de6a5bca86 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 29 May 2024 13:44:16 +0200 Subject: [PATCH 04/38] Update fill_doc dictionary --- mne_connectivity/utils/docs.py | 146 +++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 45 deletions(-) diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index bc379e36..8905f519 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -14,9 +14,7 @@ docdict = dict() # Connectivity -docdict[ - "data" -] = """ +docdict["data"] = """ data : np.ndarray ([epochs], n_estimated_nodes, [freqs], [times]) The connectivity data that is a raveled array of ``(n_estimated_nodes, ...)`` shape. The @@ -26,9 +24,7 @@ equal to the length of ``indices`` passed in. """ -docdict[ - "names" -] = """ +docdict["names"] = """ names : list | np.ndarray | None The names of the nodes of the dataset used to compute connectivity. If 'None' (default), then names will be @@ -36,9 +32,7 @@ of names, then it must be equal in length to ``n_nodes``. """ -docdict[ - "indices" -] = """ +docdict["indices"] = """ indices : tuple of arrays | str | None The indices of relevant connectivity data. If ``'all'`` (default), then data is connectivity between all nodes. If ``'symmetric'``, @@ -47,18 +41,14 @@ represents the "out nodes". See "Notes" for more information. """ -docdict[ - "n_nodes" -] = """ +docdict["n_nodes"] = """ n_nodes : int The number of nodes in the dataset used to compute connectivity. This should be equal to the number of signals in the original dataset. """ -docdict[ - "connectivity_kwargs" -] = """ +docdict["connectivity_kwargs"] = """ **kwargs : dict Extra connectivity parameters. These may include ``freqs`` for spectral connectivity, and/or @@ -67,6 +57,25 @@ as xarray ``attrs``. """ +docdict["mt_bandwidth"] = """ +mt_bandwidth : int | float | None (default None) + The bandwidth of the multitaper windowing function in Hz to use when + computing the cross-spectral density. Only used if ``mode="multitaper"``. 
+""" + +docdict["mt_adaptive"] = """ +mt_adaptive : bool (default False) + Whether to use adaptive weights when combining the tapered spectra in the + cross-spectral density. Only used if ``mode="multitaper"``. +""" + +docdict["mt_low_bias"] = """ +mt_low_bias : bool (default True) + Whether to use tapers with over 90 percent spectral concentration within + the bandwidth when computing the cross-spectral density. Only used if + ``mode="multitaper"``. +""" + docdict["coh"] = "'coh' : Coherence" docdict["cohy"] = "'cohy' : Coherency" docdict["imcoh"] = "'imcoh' : Imaginary part of Coherency" @@ -85,57 +94,43 @@ docdict["gc_tr"] = "'gc_tr' : State-space GC on time-reversed signals" # Downstream container variables -docdict[ - "freqs" -] = """ +docdict["freqs"] = """ freqs : list | np.ndarray The frequencies at which the connectivity data is computed over. If the frequencies are "frequency bands" (i.e. gamma band), then these are the median of those bands. """ -docdict[ - "times" -] = """ +docdict["times"] = """ times : list | np.ndarray The times at which the connectivity data is computed over. """ -docdict[ - "method" -] = """ +docdict["method"] = """ method : str, optional The method name used to compute connectivity. """ -docdict[ - "spec_method" -] = """ +docdict["spec_method"] = """ spec_method : str, optional The type of method used to compute spectral analysis, by default None. """ -docdict[ - "n_epochs_used" -] = """ +docdict["n_epochs_used"] = """ n_epochs_used : int, optional The number of epochs used in the computation of connectivity, by default None. """ -docdict[ - "events" -] = """ +docdict["events"] = """ events : array of int, shape (n_events, 3) The events typically returned by the read_events function. If some events don't match the events of interest as specified by event_id, they will be marked as 'IGNORED' in the drop log. """ -docdict[ - "event_id" -] = """ +docdict["event_id"] = """ event_id : int | list of int | dict | None The id of the event to consider. If dict, the keys can later be used to access associated events. Example: @@ -147,27 +142,21 @@ """ # Verbose -docdict[ - "verbose" -] = """ +docdict["verbose"] = """ verbose : bool, str, int, or None If not None, override default verbose level (see :func:`mne.verbose` for more info). If used, it should be passed as a keyword-argument only.""" # Parallelization -docdict[ - "n_jobs" -] = """ +docdict["n_jobs"] = """ n_jobs : int The number of jobs to run in parallel (default 1). Requires the joblib package. """ # Random state -docdict[ - "random_state" -] = """ +docdict["random_state"] = """ random_state : None | int | instance of ~numpy.random.RandomState If ``random_state`` is an :class:`int`, it will be used as a seed for :class:`~numpy.random.RandomState`. If ``None``, the seed will be @@ -176,6 +165,73 @@ ``None``. """ +# Decoding +docdict["info_decoding"] = """ +info : mne.Info + Information about the data which will be decomposed and transformed, such + as that coming from an :class:`mne.Epochs` object. The number of channels + must match the subsequent input data. +""" + +docdict["fmin_decoding"] = """ +fmin : int | float + The lowest frequency of interest in Hz. +""" + +docdict["fmax_decoding"] = """ +fmax : int | float + The highest frequency of interest in Hz. +""" + +docdict["indices_decoding"] = """ +indices : tuple of array + A tuple of two arrays, containing the indices of the seed and target + channels in the input data, respectively. The indices of only a single + connection (i.e. 
between one group of seeds and one group of targets) is
+    supported.
+"""
+
+docdict["mode_decoding"] = """
+mode : str (default "multitaper")
+    The cross-spectral density computation method. Can be ``"multitaper"`` or
+    ``"fourier"``.
+"""
+
+docdict["n_components"] = """
+n_components : int | None (default None)
+    The number of connectivity components (sources) to extract from the data.
+    If `None`, the number of components equal to the minimum rank of the seeds
+    and targets is extracted (see the ``rank`` parameter). If an `int`, the
+    number of components must be <= the minimum rank of the seeds and targets.
+    E.g. if the seed channels had a rank of 5 and the target channels had a
+    rank of 3, ``n_components`` must be <= 3.
+"""
+
+docdict["rank"] = """
+rank : tuple of int | None (default None)
+    A tuple of two ints, containing the degree of rank subspace projection to
+    apply to the seed and target data, respectively, before filters are fit. If
+    `None`, the rank of the seed and target data is used. If a tuple of ints,
+    the entries must be <= the rank of the seed and target data. The minimum
+    rank of the seeds and targets determines the maximum number of connectivity
+    components (sources) which can be extracted from the data (see the
+    ``n_components`` parameter). Specifying ranks below that of the data may
+    reduce the degree of overfitting when computing the filters.
+"""
+
+docdict["filters_"] = """
+filters_ : tuple of array, shape=(n_signals, n_components)
+    A tuple of two arrays containing the spatial filters for transforming the
+    seed and target data, respectively.
+"""
+
+docdict["patterns_"] = """
+patterns_ : tuple of array, shape=(n_components, n_signals)
+    A tuple of two arrays containing the spatial patterns corresponding to the
+    spatial filters for the seed and target data, respectively.
+"""
+
+
 docdict_indented = dict()  # type: ignore

From 115757cc0ce96320f4ae41965a72fd8e50bafbb8 Mon Sep 17 00:00:00 2001
From: Thomas Samuel Binns
Date: Wed, 29 May 2024 13:44:41 +0200
Subject: [PATCH 05/38] Add n_components to ignored numpydoc words

---
 doc/conf.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doc/conf.py b/doc/conf.py
index 05347f53..c2765602 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -111,6 +111,7 @@
     "n_node_names",
     "n_tapers",
     "n_signals",
+    "n_components",
     "n_step",
     "n_freqs",
     "epochs",

From e7c9da9e4379207bff5feb7c4684daf0f92efef6 Mon Sep 17 00:00:00 2001
From: Thomas Samuel Binns
Date: Wed, 29 May 2024 13:44:57 +0200
Subject: [PATCH 06/38] Add decoding module

---
 mne_connectivity/decoding/__init__.py  |   1 +
 mne_connectivity/decoding/coherency.py | 445 +++++++++++++++++++++++
 2 files changed, 446 insertions(+)
 create mode 100644 mne_connectivity/decoding/__init__.py
 create mode 100644 mne_connectivity/decoding/coherency.py

diff --git a/mne_connectivity/decoding/__init__.py b/mne_connectivity/decoding/__init__.py
new file mode 100644
index 00000000..8cf662c2
--- /dev/null
+++ b/mne_connectivity/decoding/__init__.py
@@ -0,0 +1 @@
+from .coherency import MIC, CaCoh
diff --git a/mne_connectivity/decoding/coherency.py b/mne_connectivity/decoding/coherency.py
new file mode 100644
index 00000000..00b00419
--- /dev/null
+++ b/mne_connectivity/decoding/coherency.py
@@ -0,0 +1,445 @@
+# Authors: Thomas S. 
Binns +# +# License: BSD (3-clause) + +from typing import Optional + +import numpy as np +from mne import Info +from mne.decoding.mixin import TransformerMixin +from mne.fixes import BaseEstimator +from mne.time_frequency import csd_array_fourier, csd_array_multitaper +from mne.utils import _check_option, _validate_type + +from ..spectral.epochs_multivariate import ( + _CaCohEst, + _check_rank_input, + _EpochMeanMultivariateConEstBase, + _MICEst, +) +from ..utils import _check_multivariate_indices, fill_doc + + +class _AbstractDecompositionBase(BaseEstimator, TransformerMixin): + """ABC for multivariate connectivity signal decomposition.""" + + filters_: Optional[tuple] = None + patterns_: Optional[tuple] = None + + _indices: Optional[tuple] = None + _rank: Optional[tuple] = None + _conn_estimator: Optional[_EpochMeanMultivariateConEstBase] = None + + @property + def indices(self): + """Get ``indices`` parameter in the input format.""" + return (self._indices[0].compressed(), self._indices[1].compressed()) + + @indices.setter + def indices(self, indices): + """Set ``indices`` parameter using the input format.""" + self._indices = (np.array([indices[0]]), np.array([indices[1]])) + + @property + def rank(self): + """Get ``rank`` parameter in the input format.""" + return (self._rank[0][0], self._rank[1][0]) + + @rank.setter + def rank(self, rank): + """Set ``rank`` parameter using the input format.""" + self._rank = ([rank[0]], [rank[1]]) + + def __init__( + self, + info, + fmin, + fmax, + indices, + mode="multitaper", + mt_bandwidth=None, + mt_adaptive=False, + mt_low_bias=True, + n_components=None, + rank=None, + n_jobs=1, + verbose=None, + ): + """Initialise instance.""" + # Validate inputs + _validate_type(info, Info, "`info`", "mne.Info") + + _validate_type(fmin, (int, float), "`fmin`", "int or float") + _validate_type(fmax, (int, float), "`fmax`", "int or float") + if fmin > fmax: + raise ValueError("`fmax` must be larger than `fmin`") + if fmax > info["sfreq"] / 2: + raise ValueError("`fmax` cannot be larger than the Nyquist frequency") + + _validate_type(indices, tuple, "`indices`", "tuple of lists") + if len(indices) != 2: + raise ValueError("`indices` must be have length 2") + for indices_group in indices: + _validate_type( + indices_group, + (list, tuple, np.ndarray), + "`indices`", + "tuple of lists, tuples, or NumPy arrays", + ) + _indices = self._check_indices(indices, info["nchan"]) + + _check_option("mode", mode, ("multitaper", "fourier")) + _validate_type( + mt_bandwidth, (int, float, None), "`mt_bandwidth`", "int, float, or None" + ) + _validate_type(mt_adaptive, bool, "`mt_adaptive`", "bool") + _validate_type(mt_low_bias, bool, "`mt_low_bias`", "bool") + + _validate_type(n_components, (int, None), "`n_components`", "int or None") + + _validate_type(rank, (tuple, None), "`rank`", "tuple of ints or None") + if rank is not None: + if len(rank != 2): + raise ValueError("`rank` must be have length 2") + for rank_group in rank: + _validate_type(rank_group, int, "`rank`", "tuple of ints or None") + _rank = self._check_rank(rank, indices) + + _validate_type(n_jobs, int, "`n_jobs`", "int") + + _validate_type( + verbose, (bool, str, int, None), "`verbose`", "bool, str, int, or None" + ) + + # Store inputs + self.info = info + self.fmin = fmin + self.fmax = fmax + self._indices = _indices # uses getter/setter for public parameter + self.mode = mode + self.mt_bandwidth = mt_bandwidth + self.mt_adaptive = mt_adaptive + self.mt_low_bias = mt_low_bias + self.n_components = 1 # XXX: fixed 
until n_comps > 1 supported + self._rank = _rank # uses getter/setter for public parameter + self.n_jobs = n_jobs + self.verbose = verbose + + def _check_indices(self, indices, n_chans): + """Check that the indices input is valid.""" + # convert to multivariate format and check validity + indices = _check_multivariate_indices(([indices[0]], [indices[1]]), n_chans) + + # find whether entries of indices exceed number of channels + max_idx = np.max(indices.compressed()) + if max_idx + 1 > n_chans: + raise ValueError( + "At least one entry in `indices` is greater than the number " + "of channels in `info`" + ) + + return indices + + def _check_rank(self, rank, indices): + """Check that the rank input is valid.""" + if rank is not None: + # convert to multivariate format + rank = ([rank[0]], [rank[1]]) + + # find whether entries of rank exceed number of channels in indices + if rank[0][0] > len(indices[0]) or rank[1][0] > len(indices[1]): + raise ValueError( + "At least one entry in `rank` is greater than the number " + "of seed/target channels in `indices`" + ) + + return rank + + def fit(self, X, y=None): + """Compute connectivity decomposition filters for epoched data. + + Parameters + ---------- + X : array, shape=(n_epochs, n_signals, n_times) + The input data which the connectivity decomposition filters should + be fit to. + y : None + Used for scikit-learn compatibility. + + Returns + ------- + self : instance of CaCoh | MIC + The modified class instance. + """ + # validate input data + self._check_X(X, ndim=[3]) + self._get_rank_and_ncomps_from_X(X) + + # compute CSD + csd = self._compute_csd(X) + + # instantiate connectivity estimator and add CSD information + self._conn_estimator = self._conn_estimator( + n_signals=X.shape[1], + n_cons=1, + n_freqs=1, + n_times=0, + n_jobs=self.n_jobs, + store_filters=True, + ) + self._conn_estimator.accumulate(con_idx=np.arange(csd.shape[0]), csd_xy=csd) + + # fit filters to data and compute corresponding patterns + self._conn_estimator.compute_con( + indices=self._indices, ranks=self._rank, n_epochs=1 + ) + + # extract filters and patterns + self._extract_filters_and_patterns() + + return self + + def _check_X(self, X, ndim): + """Check that the input data is valid.""" + # check data is a 2/3D array + _validate_type(X, np.ndarray, "`X`", "NumPy array") + _check_option("`X.ndim`", X.ndim, ndim) + n_chans = X.shape[1] + if n_chans != self.info["nchan"]: + raise ValueError( + "`X` does not match Info\nExpected %i channels, got %i" + % (n_chans, self.info["nchan"]) + ) + + def _get_rank_and_ncomps_from_X(self, X): + """Get/validate rank and n_components parameters using the data.""" + # compute rank from data if necessary / check it is valid for the indices + self._rank = _check_rank_input(self._rank, X, self._indices) + + # set n_components if necessary / check it is valid for the rank + if self.n_components is None: + self.n_components = np.min(self.rank) + elif self.n_components > np.min(self.rank): + raise ValueError( + "`n_components` is greater than the minimum rank of the data" + ) + + def _compute_csd(self, X): + """Compute the cross-spectral density of the input data.""" + # XXX: fix csd returning [fmin +1 bin to fmax -1 bin] frequencies + csd_kwargs = { + "X": X, + "sfreq": self.info["sfreq"], + "fmin": self.fmin, + "fmax": self.fmax, + "n_jobs": self.n_jobs, + } + if self.mode == "multitaper": + csd_kwargs.update( + { + "bandwidth": self.mt_bandwidth, + "adaptive": self.mt_adaptive, + "low_bias": self.mt_low_bias, + } + ) + csd = 
csd_array_multitaper(**csd_kwargs) + else: + csd = csd_array_fourier(**csd_kwargs) + + csd = csd.sum(self.fmin, self.fmax).get_data(index=0) + csd = np.reshape(csd, csd.shape[0] ** 2) + + return np.expand_dims(csd, 1) + + def _extract_filters_and_patterns(self): + """Extract filters and patterns from the connectivity estimator.""" + # XXX: need to sort indices and transpose patterns when multiple comps supported + self.filters_ = ( + self._conn_estimator.filters[0, 0, : len(self.indices[0]), 0], + self._conn_estimator.filters[1, 0, : len(self.indices[1]), 0], + ) + + self.patterns_ = ( + self._conn_estimator.patterns[0, 0, : len(self.indices[0]), 0], + self._conn_estimator.patterns[1, 0, : len(self.indices[1]), 0], + ) + + # XXX: remove once support for multiple comps implemented + self.filters_ = ( + np.expand_dims(self.filters_[0], 1), + np.expand_dims(self.filters_[1], 1), + ) + self.patterns_ = ( + np.expand_dims(self.patterns_[0], 0), + np.expand_dims(self.patterns_[1], 0), + ) + + def transform(self, X): + """Decompose data into connectivity sources using the fitted filters. + + Parameters + ---------- + X : array, shape=((n_epochs, ) n_signals, n_times) + The data to be transformed by the connectivity decoposition + filters. + + Returns + ------- + X_transformed : array, shape=((n_epochs, ) n_components*2, n_times) + The transformed data. The first ``n_components`` channels are the + transformed seeds, and the last ``n_components`` channels are the + transformed targets. + """ + self._check_X(X, ndim=(2, 3)) + if self.filters_ is None: + raise RuntimeError( + "no filters are available, please call the `fit` method first" + ) + + # transform seed and target data + X_seeds = self.filters_[0].T @ X[..., self.indices[0], :] + X_targets = self.filters_[1].T @ X[..., self.indices[1], :] + + return np.concatenate((X_seeds, X_targets), axis=-2) + + def fit_transform(self, X, y=None, **fit_params): + """Fit filters to data, then transform and return it. + + Parameters + ---------- + X : array, shape=(n_epochs, n_signals, n_times) + The input data which the connectivity decomposition filters should + be fit to and subsequently transformed. + y : None + Used for scikit-learn compatibility. + **fit_params : dict + Additional fitting parameters passed to the ``fit`` method. Not + used for this class. + + Returns + ------- + X_transformed : array, shape=(n_epochs, n_components*2, n_times) + The transformed data. The first ``n_components`` channels are the + transformed seeds, and the last ``n_components`` channels are the + transformed targets. + """ + # custom docstring, but uses parent TransformerMixin method + + def get_transformed_indices(self): + """Get indices for the transformed data. + + Returns + ------- + indices_transformed : tuple of array + Indices of seeds and targets in the transformed data with the form + (seeds, targets) to be used when passing the data to + `~mne_connectivity.spectral_connectivity_epochs` and + `~mne_connectivity.spectral_connectivity_time`. Entries of the + indices are arranged such that connectivity would be computed + between the first seed component and first target component, second + seed component and second target component, etc... + """ + return ( + np.arange(self.n_components), + np.arange(self.n_components) + self.n_components, + ) + + +@fill_doc +class CaCoh(_AbstractDecompositionBase): + """Decompose connectivity sources using canonical coherency (CaCoh). 
+ + CaCoh is a multivariate approach to maximise coherency/coherence between a + set of seed and target signals in a frequency-resolved manner + :footcite:`VidaurreEtAl2019`. The maximisation of connectivity involves + fitting spatial filters to the cross-spectral density of the seed and + target data, alongisde which spatial patterns of the contributions to + connectivity can be computed :footcite:`HaufeEtAl2014`. + + Once fit, the filters can be used to transform data into the underlying + connectivity components. Connectivity can be computed on this transformed + data using the ``"coh"`` and ``"cohy"`` methods of the + `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. + + The approach taken here is to optimise the connectivity in a given + frequency band. Frequency bin-wise optimisation is offered in the + ``"cacoh"`` method of the `mne_connectivity.spectral_connectivity_epochs` + and `mne_connectivity.spectral_connectivity_time` functions. + + Parameters + ---------- + %(info_decoding)s + %(fmin_decoding)s + %(fmax_decoding)s + %(indices_decoding)s + %(mode_decoding)s + %(mt_bandwidth)s + %(mt_adaptive)s + %(mt_low_bias)s + %(n_components)s + %(rank)s + %(n_jobs)s + %(verbose)s + + Attributes + ---------- + %(filters_)s + %(patterns_)s + + References + ---------- + .. footbibliography:: + """ + + _conn_estimator = _CaCohEst + + +@fill_doc +class MIC(_AbstractDecompositionBase): + """Decompose connectivity sources using maximised imaginary coherency (MIC). + + MIC is a multivariate approach to maximise the imaginary part of coherency + between a set of seed and target signals in a frequency-resolved manner + :footcite:`EwaldEtAl2012`. The maximisation of connectivity involves + fitting spatial filters to the cross-spectral density of the seed and + target data, alongisde which spatial patterns of the contributions to + connectivity can be computed :footcite:`HaufeEtAl2014`. + + Once fit, the filters can be used to transform data into the underlying + connectivity components. Connectivity can be computed on this transformed + data using the ``"imcoh"`` method of the + `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. + + The approach taken here is to optimise the connectivity in a given + frequency band. Frequency bin-wise optimisation is offered in the + ``"mic"`` method of the `mne_connectivity.spectral_connectivity_epochs` + and `mne_connectivity.spectral_connectivity_time` functions. + + Parameters + ---------- + %(info_decoding)s + %(fmin_decoding)s + %(fmax_decoding)s + %(indices_decoding)s + %(mode_decoding)s + %(mt_bandwidth)s + %(mt_adaptive)s + %(mt_low_bias)s + %(n_components)s + %(rank)s + %(n_jobs)s + %(verbose)s + + Attributes + ---------- + %(filters_)s + %(patterns_)s + + References + ---------- + .. footbibliography:: + """ + + _conn_estimator = _MICEst From 5e51923834994312a921070c29e7ee08705086dd Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 29 May 2024 13:46:48 +0200 Subject: [PATCH 07/38] Update API with decoding module --- doc/api.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index f919f74b..c64132cc 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -49,9 +49,25 @@ on numpy array inputs. 
spectral_connectivity_epochs spectral_connectivity_time +Decoding classes +================ + +These classes fit filters which decompose data into discrete sources of +connectivity, amplifying the signal-to-noise ratio of these interactions. + +.. currentmodule:: mne_connectivity.decoding + +.. autosummary:: + :toctree: generated/ + + CaCoh + MIC + Reading functions ================= +.. currentmodule:: mne_connectivity + .. autosummary:: :toctree: generated/ From 9fe0f2859da44420b11d0ecc8e3feb921235e709 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 29 May 2024 17:06:03 +0200 Subject: [PATCH 08/38] Rename file and add suport for cwt_morlet mode --- mne_connectivity/decoding/__init__.py | 2 +- .../{coherency.py => decomposition.py} | 141 +++++++++++------- mne_connectivity/utils/docs.py | 79 +++++----- 3 files changed, 130 insertions(+), 92 deletions(-) rename mne_connectivity/decoding/{coherency.py => decomposition.py} (79%) diff --git a/mne_connectivity/decoding/__init__.py b/mne_connectivity/decoding/__init__.py index 8cf662c2..444470ae 100644 --- a/mne_connectivity/decoding/__init__.py +++ b/mne_connectivity/decoding/__init__.py @@ -1 +1 @@ -from .coherency import MIC, CaCoh +from .decomposition import MIC, CaCoh diff --git a/mne_connectivity/decoding/coherency.py b/mne_connectivity/decoding/decomposition.py similarity index 79% rename from mne_connectivity/decoding/coherency.py rename to mne_connectivity/decoding/decomposition.py index 00b00419..80f86e26 100644 --- a/mne_connectivity/decoding/coherency.py +++ b/mne_connectivity/decoding/decomposition.py @@ -8,7 +8,7 @@ from mne import Info from mne.decoding.mixin import TransformerMixin from mne.fixes import BaseEstimator -from mne.time_frequency import csd_array_fourier, csd_array_multitaper +from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type from ..spectral.epochs_multivariate import ( @@ -60,6 +60,8 @@ def __init__( mt_bandwidth=None, mt_adaptive=False, mt_low_bias=True, + cwt_freq_resolution=1, + cwt_n_cycles=7, n_components=None, rank=None, n_jobs=1, @@ -76,7 +78,7 @@ def __init__( if fmax > info["sfreq"] / 2: raise ValueError("`fmax` cannot be larger than the Nyquist frequency") - _validate_type(indices, tuple, "`indices`", "tuple of lists") + _validate_type(indices, tuple, "`indices`", "tuple of array-like") if len(indices) != 2: raise ValueError("`indices` must be have length 2") for indices_group in indices: @@ -84,17 +86,28 @@ def __init__( indices_group, (list, tuple, np.ndarray), "`indices`", - "tuple of lists, tuples, or NumPy arrays", + "tuple of array-likes", ) _indices = self._check_indices(indices, info["nchan"]) - _check_option("mode", mode, ("multitaper", "fourier")) + _check_option("mode", mode, ("multitaper", "fourier", "cwt_morlet")) + _validate_type( mt_bandwidth, (int, float, None), "`mt_bandwidth`", "int, float, or None" ) _validate_type(mt_adaptive, bool, "`mt_adaptive`", "bool") _validate_type(mt_low_bias, bool, "`mt_low_bias`", "bool") + _validate_type( + cwt_freq_resolution, (int, float), "`cwt_freq_resolution`", "int or float" + ) + _validate_type( + cwt_n_cycles, + (int, float, tuple, list, np.ndarray), + "`cwt_n_cycles`", + "int, float, or array-like of ints or floats", + ) + _validate_type(n_components, (int, None), "`n_components`", "int or None") _validate_type(rank, (tuple, None), "`rank`", "tuple of ints or None") @@ -120,6 +133,8 @@ def __init__( self.mt_bandwidth = mt_bandwidth 
self.mt_adaptive = mt_adaptive self.mt_low_bias = mt_low_bias + self.cwt_freq_resolution = cwt_freq_resolution + self.cwt_n_cycles = cwt_n_cycles self.n_components = 1 # XXX: fixed until n_comps > 1 supported self._rank = _rank # uses getter/setter for public parameter self.n_jobs = n_jobs @@ -134,8 +149,8 @@ def _check_indices(self, indices, n_chans): max_idx = np.max(indices.compressed()) if max_idx + 1 > n_chans: raise ValueError( - "At least one entry in `indices` is greater than the number " - "of channels in `info`" + "At least one entry in `indices` is greater than the number of " + "channels in `info`" ) return indices @@ -149,8 +164,8 @@ def _check_rank(self, rank, indices): # find whether entries of rank exceed number of channels in indices if rank[0][0] > len(indices[0]) or rank[1][0] > len(indices[1]): raise ValueError( - "At least one entry in `rank` is greater than the number " - "of seed/target channels in `indices`" + "At least one entry in `rank` is greater than the number of " + "seed/target channels in `indices`" ) return rank @@ -161,8 +176,8 @@ def fit(self, X, y=None): Parameters ---------- X : array, shape=(n_epochs, n_signals, n_times) - The input data which the connectivity decomposition filters should - be fit to. + The input data which the connectivity decomposition filters should be fit + to. y : None Used for scikit-learn compatibility. @@ -230,21 +245,34 @@ def _compute_csd(self, X): csd_kwargs = { "X": X, "sfreq": self.info["sfreq"], - "fmin": self.fmin, - "fmax": self.fmax, "n_jobs": self.n_jobs, } if self.mode == "multitaper": csd_kwargs.update( { + "fmin": self.fmin, + "fmax": self.fmax, "bandwidth": self.mt_bandwidth, "adaptive": self.mt_adaptive, "low_bias": self.mt_low_bias, } ) csd = csd_array_multitaper(**csd_kwargs) - else: + elif self.mode == "fourier": + csd_kwargs.update({"fmin": self.fmin, "fmax": self.fmax}) csd = csd_array_fourier(**csd_kwargs) + else: + csd_kwargs.update( + { + "frequencies": np.arange( + self.fmin, + self.fmax + self.cwt_freq_resolution, + self.cwt_freq_resolution, + ), + "n_cycles": self.cwt_n_cycles, + } + ) + csd = csd_array_morlet(**csd_kwargs) csd = csd.sum(self.fmin, self.fmax).get_data(index=0) csd = np.reshape(csd, csd.shape[0] ** 2) @@ -280,8 +308,7 @@ def transform(self, X): Parameters ---------- X : array, shape=((n_epochs, ) n_signals, n_times) - The data to be transformed by the connectivity decoposition - filters. + The data to be transformed by the connectivity decomposition filters. Returns ------- @@ -308,13 +335,13 @@ def fit_transform(self, X, y=None, **fit_params): Parameters ---------- X : array, shape=(n_epochs, n_signals, n_times) - The input data which the connectivity decomposition filters should - be fit to and subsequently transformed. + The input data which the connectivity decomposition filters should be fit to + and subsequently transformed. y : None Used for scikit-learn compatibility. **fit_params : dict - Additional fitting parameters passed to the ``fit`` method. Not - used for this class. + Additional fitting parameters passed to the ``fit`` method. Not used for + this class. Returns ------- @@ -323,7 +350,7 @@ def fit_transform(self, X, y=None, **fit_params): transformed seeds, and the last ``n_components`` channels are the transformed targets. """ - # custom docstring, but uses parent TransformerMixin method + # use parent TransformerMixin method but with custom docstring def get_transformed_indices(self): """Get indices for the transformed data. 
@@ -331,13 +358,13 @@ def get_transformed_indices(self): Returns ------- indices_transformed : tuple of array - Indices of seeds and targets in the transformed data with the form - (seeds, targets) to be used when passing the data to + Indices of seeds and targets in the transformed data with the form (seeds, + targets) to be used when passing the data to `~mne_connectivity.spectral_connectivity_epochs` and - `~mne_connectivity.spectral_connectivity_time`. Entries of the - indices are arranged such that connectivity would be computed - between the first seed component and first target component, second - seed component and second target component, etc... + `~mne_connectivity.spectral_connectivity_time`. Entries of the indices are + arranged such that connectivity would be computed between the first seed + component and first target component, second seed component and second + target component, etc... """ return ( np.arange(self.n_components), @@ -349,23 +376,22 @@ def get_transformed_indices(self): class CaCoh(_AbstractDecompositionBase): """Decompose connectivity sources using canonical coherency (CaCoh). - CaCoh is a multivariate approach to maximise coherency/coherence between a - set of seed and target signals in a frequency-resolved manner - :footcite:`VidaurreEtAl2019`. The maximisation of connectivity involves - fitting spatial filters to the cross-spectral density of the seed and - target data, alongisde which spatial patterns of the contributions to - connectivity can be computed :footcite:`HaufeEtAl2014`. - - Once fit, the filters can be used to transform data into the underlying - connectivity components. Connectivity can be computed on this transformed - data using the ``"coh"`` and ``"cohy"`` methods of the + CaCoh is a multivariate approach to maximise coherency/coherence between a set of + seed and target signals in a frequency-resolved manner :footcite:`VidaurreEtAl2019`. + The maximisation of connectivity involves fitting spatial filters to the + cross-spectral density of the seed and target data, alongisde which spatial patterns + of the contributions to connectivity can be computed :footcite:`HaufeEtAl2014`. + + Once fit, the filters can be used to transform data into the underlying connectivity + components. Connectivity can be computed on this transformed data using the + ``"coh"`` and ``"cohy"`` methods of the `mne_connectivity.spectral_connectivity_epochs` and `mne_connectivity.spectral_connectivity_time` functions. - The approach taken here is to optimise the connectivity in a given - frequency band. Frequency bin-wise optimisation is offered in the - ``"cacoh"`` method of the `mne_connectivity.spectral_connectivity_epochs` - and `mne_connectivity.spectral_connectivity_time` functions. + The approach taken here is to optimise the connectivity in a given frequency band. + Frequency bin-wise optimisation is offered in the ``"cacoh"`` method of the + `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. Parameters ---------- @@ -373,10 +399,12 @@ class CaCoh(_AbstractDecompositionBase): %(fmin_decoding)s %(fmax_decoding)s %(indices_decoding)s - %(mode_decoding)s + %(mode)s %(mt_bandwidth)s %(mt_adaptive)s %(mt_low_bias)s + %(cwt_freq_resolution)s + %(cwt_n_cycles)s %(n_components)s %(rank)s %(n_jobs)s @@ -399,34 +427,35 @@ class CaCoh(_AbstractDecompositionBase): class MIC(_AbstractDecompositionBase): """Decompose connectivity sources using maximised imaginary coherency (MIC). 
- MIC is a multivariate approach to maximise the imaginary part of coherency - between a set of seed and target signals in a frequency-resolved manner - :footcite:`EwaldEtAl2012`. The maximisation of connectivity involves - fitting spatial filters to the cross-spectral density of the seed and - target data, alongisde which spatial patterns of the contributions to - connectivity can be computed :footcite:`HaufeEtAl2014`. + MIC is a multivariate approach to maximise the imaginary part of coherency between a + set of seed and target signals in a frequency-resolved manner + :footcite:`EwaldEtAl2012`. The maximisation of connectivity involves fitting spatial + filters to the cross-spectral density of the seed and target data, alongisde which + spatial patterns of the contributions to connectivity can be computed + :footcite:`HaufeEtAl2014`. + + Once fit, the filters can be used to transform data into the underlying connectivity + components. Connectivity can be computed on this transformed data using the + ``"imcoh"`` method of the `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. - Once fit, the filters can be used to transform data into the underlying - connectivity components. Connectivity can be computed on this transformed - data using the ``"imcoh"`` method of the + The approach taken here is to optimise the connectivity in a given frequency band. + Frequency bin-wise optimisation is offered in the ``"mic"`` method of the `mne_connectivity.spectral_connectivity_epochs` and `mne_connectivity.spectral_connectivity_time` functions. - The approach taken here is to optimise the connectivity in a given - frequency band. Frequency bin-wise optimisation is offered in the - ``"mic"`` method of the `mne_connectivity.spectral_connectivity_epochs` - and `mne_connectivity.spectral_connectivity_time` functions. - Parameters ---------- %(info_decoding)s %(fmin_decoding)s %(fmax_decoding)s %(indices_decoding)s - %(mode_decoding)s + %(mode)s %(mt_bandwidth)s %(mt_adaptive)s %(mt_low_bias)s + %(cwt_freq_resolution)s + %(cwt_n_cycles)s %(n_components)s %(rank)s %(n_jobs)s diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index 8905f519..3fc6e5c8 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -57,10 +57,16 @@ as xarray ``attrs``. """ +docdict["mode"] = """ +mode : str (default "multitaper") + The cross-spectral density computation method. Can be ``"multitaper"``, + ``"fourier"``, or ``"cwt_morlet"``. +""" + docdict["mt_bandwidth"] = """ mt_bandwidth : int | float | None (default None) - The bandwidth of the multitaper windowing function in Hz to use when - computing the cross-spectral density. Only used if ``mode="multitaper"``. + The bandwidth of the multitaper windowing function in Hz to use when computing the + cross-spectral density. Only used if ``mode="multitaper"``. """ docdict["mt_adaptive"] = """ @@ -71,11 +77,23 @@ docdict["mt_low_bias"] = """ mt_low_bias : bool (default True) - Whether to use tapers with over 90 percent spectral concentration within - the bandwidth when computing the cross-spectral density. Only used if + Whether to use tapers with over 90 percent spectral concentration within the + bandwidth when computing the cross-spectral density. Only used if ``mode="multitaper"``. """ +docdict["cwt_freq_resolution"] = """ +cwt_freq_resolution : int | float (default 1) + The frequency resolution of the cross-spectral density in Hz. Only used if + ``mode=cwt_morlet``. 
+""" + +docdict["cwt_n_cycles"] = """ +cwt_n_cycles : int | float | array of int or float (default 7) + The number of cycles to use when constructing the Morlet wavelets. Fixed number or + one per frequency. Only used if ``mode=cwt_morlet``. +""" + docdict["coh"] = "'coh' : Coherence" docdict["cohy"] = "'cohy' : Coherency" docdict["imcoh"] = "'imcoh' : Imaginary part of Coherency" @@ -168,9 +186,9 @@ # Decoding docdict["info_decoding"] = """ info : mne.Info - Information about the data which will be decomposed and transformed, such - as that coming from an :class:`mne.Epochs` object. The number of channels - must match the subsequent input data. + Information about the data which will be decomposed and transformed, such as that + coming from an :class:`mne.Epochs` object. The number of channels must match the + subsequent input data. """ docdict["fmin_decoding"] = """ @@ -185,50 +203,41 @@ docdict["indices_decoding"] = """ indices : tuple of array - A tuple of two arrays, containing the indices of the seed and target - channels in the input data, respectively. The indices of only a single - connection (i.e. between one group of seeds and one group of targets) is - supported. -""" - -docdict["mode_decoding"] = """ -mode : str (default "multitaper") - The cross-spectral density computation method. Can be ``"multitaper"`` or - ``"fourier"``. + A tuple of two arrays, containing the indices of the seed and target channels in the + input data, respectively. The indices of only a single connection (i.e. between one + group of seeds and one group of targets) is supported. """ docdict["n_components"] = """ n_components : int | None (default None) - The number of connectivity components (sources) to extract from the data. - If `None`, the number of components equal to the minimum rank of the seeds - and targets is extracted (see the ``rank`` parameter). If an `int`, the - number of components must be <= the minimum rank of the seeds and targets. - E.g. if the seed channels had a rank of 5 and the target channels had a - rank of 3, ``n_components`` must be <= 3. + The number of connectivity components (sources) to extract from the data. If `None`, + the number of components equal to the minimum rank of the seeds and targets is + extracted (see the ``rank`` parameter). If an `int`, the number of components must + be <= the minimum rank of the seeds and targets. E.g. if the seed channels had a + rank of 5 and the target channels had a rank of 3, ``n_components`` must be <= 3. """ docdict["rank"] = """ rank : tuple of int | None (default None) - A tuple of two ints, containing the degree of rank subspace projection to - apply to the seed and target data, respectively, before filters are fit. If - `None`, the rank of the seed and target data is used. If a tuple of ints, - the entries must be <= the rank of the seed and target data. The minimum - rank of the seeds and targets determines the maximum number of connectivity - components (sources) which can be extracted from the data (see the - ``n_components`` parameter). Specifying ranks below that of the data may - reduce the degree of overfitting when computing the filters. + A tuple of two ints, containing the degree of rank subspace projection to apply to + the seed and target data, respectively, before filters are fit. If `None`, the rank + of the seed and target data is used. If a tuple of ints, the entries must be <= the + rank of the seed and target data. 
The minimum rank of the seeds and targets + determines the maximum number of connectivity components (sources) which can be + extracted from the data (see the ``n_components`` parameter). Specifying ranks below + that of the data may reduce the degree of overfitting when computing the filters. """ docdict["filters_"] = """ filters_ : tuple of array, shape=(n_signals, n_components) - A tuple of two arrays containing the spatial filters for transforming the - seed and target data, respectively. + A tuple of two arrays containing the spatial filters for transforming the seed and + target data, respectively. """ docdict["patterns_"] = """ patterns_ : tuple of array, shape=(n_components, n_signals) - A tuple of two arrays containing the spatial patterns corresponding to the - spatial filters for the seed and target data, respectively. + A tuple of two arrays containing the spatial patterns corresponding to the spatial + filters for the seed and target data, respectively. """ From fa6a001d63624d29fe78b5fb4fec518b3c8dc61f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 3 Jun 2024 21:26:37 +0200 Subject: [PATCH 09/38] Make property docstrings private --- mne_connectivity/decoding/decomposition.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 80f86e26..cbc93408 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -32,7 +32,10 @@ class _AbstractDecompositionBase(BaseEstimator, TransformerMixin): @property def indices(self): - """Get ``indices`` parameter in the input format.""" + """Get ``indices`` parameter in the input format. + + :meta private: + """ return (self._indices[0].compressed(), self._indices[1].compressed()) @indices.setter @@ -42,7 +45,10 @@ def indices(self, indices): @property def rank(self): - """Get ``rank`` parameter in the input format.""" + """Get ``rank`` parameter in the input format. + + :meta private: + """ return (self._rank[0][0], self._rank[1][0]) @rank.setter From 15e99e4cacd545556ec0a76f4aab0419a4d1ca3a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 3 Jun 2024 21:27:00 +0200 Subject: [PATCH 10/38] Bug fix error check --- mne_connectivity/decoding/decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index cbc93408..d1285ad3 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -118,7 +118,7 @@ def __init__( _validate_type(rank, (tuple, None), "`rank`", "tuple of ints or None") if rank is not None: - if len(rank != 2): + if len(rank) != 2: raise ValueError("`rank` must be have length 2") for rank_group in rank: _validate_type(rank_group, int, "`rank`", "tuple of ints or None") From 58eca90da30dc8e75126661c65897eb19b817554 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 3 Jun 2024 21:27:22 +0200 Subject: [PATCH 11/38] Bug fix fit_transform no return --- mne_connectivity/decoding/decomposition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index d1285ad3..2ec7a1ef 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -357,6 +357,7 @@ def fit_transform(self, X, y=None, **fit_params): transformed targets. 
""" # use parent TransformerMixin method but with custom docstring + return super().fit_transform(X, y=y, **fit_params) def get_transformed_indices(self): """Get indices for the transformed data. From 9b43dfb6b444afd1069f158e86eb1bcbae3d50ef Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 3 Jun 2024 21:28:09 +0200 Subject: [PATCH 12/38] Bug fix _check_X 2d array --- mne_connectivity/decoding/decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 2ec7a1ef..d6311c57 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -225,7 +225,7 @@ def _check_X(self, X, ndim): # check data is a 2/3D array _validate_type(X, np.ndarray, "`X`", "NumPy array") _check_option("`X.ndim`", X.ndim, ndim) - n_chans = X.shape[1] + n_chans = X.shape[-2] if n_chans != self.info["nchan"]: raise ValueError( "`X` does not match Info\nExpected %i channels, got %i" From bb0b52052779b7e6aa5438909c0b80e951c3ba80 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 3 Jun 2024 21:32:19 +0200 Subject: [PATCH 13/38] Add preliminary decomp example --- examples/decoding/README.txt | 6 + examples/decoding/cohy_decomposition.py | 532 ++++++++++++++++++++++++ 2 files changed, 538 insertions(+) create mode 100644 examples/decoding/README.txt create mode 100644 examples/decoding/cohy_decomposition.py diff --git a/examples/decoding/README.txt b/examples/decoding/README.txt new file mode 100644 index 00000000..00535cf3 --- /dev/null +++ b/examples/decoding/README.txt @@ -0,0 +1,6 @@ + +Decoding & Decomposition Examples +--------------------------------- + +Examples demonstrating multivariate connectivity analysis using the decomposition tools +of the decoding module. \ No newline at end of file diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py new file mode 100644 index 00000000..c3c57285 --- /dev/null +++ b/examples/decoding/cohy_decomposition.py @@ -0,0 +1,532 @@ +""" +============================================================== +Multivariate decomposition for efficient connectivity analysis +============================================================== + +This example demonstrates how the tools in the decoding module can be used to +decompose data into the most relevant components of connectivity and used for +a computationally efficient multivariate analysis of connectivity, such as in +brain-computer interface (BCI) applications. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) +# sphinx_gallery_thumbnail_number = 2 + +# %% + +import time + +import mne +import numpy as np +from matplotlib import pyplot as plt +from mne import make_fixed_length_epochs +from mne.datasets.fieldtrip_cmc import data_path + +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) +from mne_connectivity.decoding import MIC, CaCoh + +######################################################################################## +# Background +# ---------- +# +# Multivariate forms of signal analysis allow you to simultaneously consider +# the activity of multiple signals. In the case of connectivity, the +# interaction between multiple sensors can be analysed at once and the strongest +# components of this interaction captured in a lower-dimensional set of connectivity +# spectra. This approach brings not only practical benefits (e.g. 
easier
# interpretability of results from the dimensionality reduction), but can also offer
# methodological improvements (e.g. enhanced signal-to-noise ratio and reduced bias).
#
# Coherency-based methods are popular approaches for analysing connectivity, capturing
# correlation between signals in the frequency domain. Various coherency-based
# multivariate methods exist, including: canonical coherency (CaCoh; multivariate
# measure of coherency/coherence); and maximised imaginary coherency (MIC; multivariate
# measure of the imaginary part of coherency).
#
# These methods are described in detail in the following examples:
# - comparison of coherency-based methods - :doc:`../compare_coherency_methods`
# - CaCoh - :doc:`../cacoh`
# - MIC - :doc:`../mic_mim`
#
# The CaCoh and MIC methods work by finding spatial filters that decompose the data into
# components of connectivity, and applying them to the data. With the implementations
# offered in :func:`~mne_connectivity.spectral_connectivity_epochs` and
# :func:`~mne_connectivity.spectral_connectivity_time`, the filters are fit for each
# frequency separately and are only applied to the same data they are fit on.
#
# Unfortunately, fitting filters for each frequency bin can be computationally
# expensive, which may prohibit the use of these techniques, e.g. in real-time BCI
# setups where the rapid analysis of data is paramount, or even in offline analyses
# with huge datasets.
#
# These issues are addressed by the :class:`~mne_connectivity.decoding.CaCoh` and
# :class:`~mne_connectivity.decoding.MIC` decomposition classes of the decoding module.
# Here, the filters are fit for a given frequency band collectively (not each frequency
# bin!) and are stored, allowing them to be applied to the same data they were fit on
# (e.g. for offline analyses of huge datasets) or to new data (e.g. for online analyses
# of streamed data).
#
# In this example, we show how the tools of the decoding module compare to the standard
# ``spectral_connectivity_...()`` functions in terms of their run time, and their
# ability to decompose data into connectivity components.

########################################################################################
# Case 1: Fitting to and transforming different data
# --------------------------------------------------
#
# We start by simulating some connectivity between two groups of signals at 15-20 Hz
# across 60 two-second-long epochs. To demonstrate the approach of fitting filters to
# one set of data and applying them to another, we will treat the first 30 epochs as
# the data on which we train the filters, and the last 30 epochs as the data we
# transform. We will use the CaCoh method, since zero time-lag interactions are not
# present (see :doc:`../compare_coherency_methods` for more information).

# %%

N_SEEDS = 10
N_TARGETS = 15

FMIN = 15
FMAX = 20

N_EPOCHS = 60

epochs = make_signals_in_freq_bands(
    n_seeds=N_SEEDS,
    n_targets=N_TARGETS,
    freq_band=(FMIN, FMAX),
    n_epochs=N_EPOCHS,
    n_times=200,
    sfreq=100,
    snr=0.2,
    rng_seed=44,
)

indices = (np.arange(N_SEEDS), np.arange(N_TARGETS) + N_SEEDS)

########################################################################################
# First, we use the standard CaCoh approach in
# :func:`~mne_connectivity.spectral_connectivity_epochs` to visualise the connectivity
# in the first 30 epochs. We also plot bivariate coherence to demonstrate the
# signal-to-noise enhancements this multivariate approach offers. As expected, we see a
# peak at 15-20 Hz in the connectivity decomposed by the spatial filters.

# %%

# Connectivity profile of first 30 epochs (filters fit to these epochs)
con_cacoh_first = spectral_connectivity_epochs(
    epochs.get_data(item=np.arange(N_EPOCHS // 2)),
    method="cacoh",
    indices=([indices[0]], [indices[1]]),
    fmin=5,
    fmax=35,
    sfreq=epochs.info["sfreq"],
    rank=([3], [3]),
)
ax = plt.subplot(111)
ax.plot(con_cacoh_first.freqs, np.abs(con_cacoh_first.get_data()[0]), label="CaCoh")

# Connectivity profile of first 30 epochs (no filters)
con_coh_first = spectral_connectivity_epochs(
    epochs.get_data(item=np.arange(N_EPOCHS // 2)),
    method="coh",
    indices=seed_target_indices(indices[0], indices[1]),
    fmin=5,
    fmax=35,
    sfreq=epochs.info["sfreq"],
)
ax.plot(con_coh_first.freqs, np.mean(con_coh_first.get_data(), axis=0), label="Coh")
ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Connectivity (A.U.)")
ax.set_title("Epochs 0-30")
plt.legend()
plt.show()

########################################################################################
# The goal of the decoding module approach is to use the information from the first 30
# epochs to fit the filters, and then use these filters to extract the same components
# from the last 30 epochs.
#
# For this, we instantiate the :class:`~mne_connectivity.decoding.CaCoh` class with: the
# information about the data being fit/transformed (using an :class:`~mne.Info` object);
# the frequency band of the components we want to decompose (here 15-20 Hz); and the
# channel indices of the seeds and targets.
#
# Next, we call the :meth:`~mne_connectivity.decoding.CaCoh.fit` method, passing in the
# first 30 epochs of data we want to fit the filters to. Once the filters are fit, we
# can apply them to the last 30 epochs using the
# :meth:`~mne_connectivity.decoding.CaCoh.transform` method.
#
# The transformed data has shape ``(epochs x components*2 x times)``, where the new
# 'channels' are organised as the seed components, then target components. For
# convenience, the :meth:`~mne_connectivity.decoding.CaCoh.get_transformed_indices`
# method can be used to get the ``indices`` of the transformed data for use in the
# ``spectral_connectivity_...()`` functions.

# %%

# Fit filters to first 30 epochs
cacoh = CaCoh(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=indices, rank=(3, 3))
cacoh.fit(epochs.get_data(item=np.arange(N_EPOCHS // 2)))

# Use filters to transform data from last 30 epochs
epochs_transformed = cacoh.transform(
    epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS))
)
indices_transformed = cacoh.get_transformed_indices()

########################################################################################
# We can now visualise the connectivity in the last 30 epochs of the transformed data,
# which for reference we will compare to connectivity in the last 30 epochs using
# filters fit to the data itself, as well as bivariate coherence to again demonstrate
# the signal-to-noise enhancements the multivariate approach offers.
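#
# Before computing connectivity, it can be useful to sanity-check the transformed
# output. As a sketch of what to expect for this simulation (with the default of one
# connectivity component, the transformed data contains one seed and one target
# 'channel'; the exact values below are illustrative)::
#
#     epochs_transformed.shape  # e.g. (30, 2, 200): epochs x components * 2 x times
#     indices_transformed       # e.g. (array([0]), array([1]))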
+# +# To compute connectivity of the transformed data, it is simply a case of passing to the +# ``spectral_connectivity_...()`` functions: the transformed data; the indices +# returned from :meth:`~mne_connectivity.decoding.CaCoh.get_transformed_indices`; and +# the corresponding bivariate method (``"coh"`` and ``"cohy"`` for CaCoh; ``"imcoh"`` +# for MIC). +# +# As you can see, the connectivity profile of the transformed data using filters fit on +# the first 30 epochs is very similar to the connectivity profile when using filters fit +# on the last 30 epochs. This shows that the filters are generalisable, able to extract +# the same components of connectivity which they were trained on from new data. + +# %% + +# Connectivity profile of last 30 epochs (filters fit to these epochs) +con_cacoh_last = spectral_connectivity_epochs( + epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)), + method="cacoh", + indices=([indices[0]], [indices[1]]), + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], + rank=([3], [3]), +) +ax = plt.subplot(111) +ax.plot( + con_cacoh_last.freqs, + np.abs(con_cacoh_last.get_data()[0]), + label="CaCoh (filters trained\non epochs 30-60)", +) + +# Connectivity profile of last 30 epochs (no filters) +con_coh_last = spectral_connectivity_epochs( + epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)), + method="coh", + indices=seed_target_indices(indices[0], indices[1]), + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], +) +ax.plot( + con_coh_last.freqs, np.mean(np.abs(con_coh_last.get_data()), axis=0), label="Coh" +) + +# Connectivity profile of last 30 epochs (filters fit to first 30 epochs) +con_cacoh_last_from_first = spectral_connectivity_epochs( + epochs_transformed, + method="coh", + indices=indices_transformed, + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], +) +ax.plot( + con_cacoh_last_from_first.freqs, + np.abs(con_cacoh_last_from_first.get_data()[0]), + label="CaCoh (filters trained\non epochs 0-30)", +) +ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("Epochs 30-60") +plt.legend() +plt.show() + +######################################################################################## +# In addition to assessing the validity of the approach, we can also look at the time +# taken to run the analysis. Below we present a scenario resembling an online sliding +# window approach typical of a BCI system. We consider the first 30 epochs to be the +# training data that the filters should be fit to, and the last 30 epochs to be the +# windows of data that the filters should be applied to, transforming and computing the +# connectivity of each window (epoch) of data sequentially. +# +# Doing so, we see that once the filters have been fit, it takes only a few milliseconds +# to transform each window of data and compute its connectivity. 
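#
# The same pattern carries over to genuinely streamed data. As a sketch, where
# ``training_data``, ``recording``, and ``get_next_window()`` are hypothetical
# stand-ins for data from an acquisition system (windows having shape
# ``(n_channels, n_times)``)::
#
#     cacoh.fit(training_data)
#     while recording:
#         window_transformed = cacoh.transform(get_next_window())
#         ...  # compute connectivity of the transformed window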
+ +# %% + +cacoh = CaCoh(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=indices, rank=(3, 3)) + +# Time fitting of filters +start_fit = time.time() +cacoh.fit(epochs.get_data(item=np.arange(N_EPOCHS // 2))) +fit_duration = (time.time() - start_fit) * 1000 + +# Time transforming data of each epoch iteratively +start_transform = time.time() +for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): + epoch_transformed = cacoh.transform(epoch) + spectral_connectivity_epochs( + np.expand_dims(epoch_transformed, axis=0), + method="coh", + indices=indices_transformed, + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], + ) +transform_duration = (time.time() - start_transform) * 1000 + +# %% + +print(f"Time to fit filters: {fit_duration:.0f} ms") +print(f"Time to transform data and compute connectivity: {transform_duration:.0f} ms") +print(f"Total time: {fit_duration + transform_duration:.0f} ms") + +print( + "\nTime to transform data and compute connectivity per epoch (window): ", + f"{transform_duration/(N_EPOCHS//2):.0f} ms", +) + +######################################################################################## +# In contrast, here we follow the same sequential window approach, but fit filters to +# each window separately rather than using a pre-computed set. Naturally, the process of +# fitting and transforming the data for each window is considerably slower. + +# %% + +# Time fitting and transforming data of each epoch iteratively +start_fit_transform = time.time() +for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): + spectral_connectivity_epochs( + np.expand_dims(epoch, axis=0), + method="cacoh", + indices=([indices[0]], [indices[1]]), + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], + rank=([3], [3]), + ) +fit_transform_duration = (time.time() - start_fit_transform) * 1000 + +# %% + +print( + f"Time to fit, transform, and compute connectivity: {fit_transform_duration:.0f} ms" +) + +print( + "\nTime to fit, transform, and compute connectivity per epoch (window): ", + f"{fit_transform_duration/(N_EPOCHS//2):.0f} ms", +) + +######################################################################################## +# As a side note, it is important to consider that a multivariate approache may be as +# fast or even faster than a bivariate approach, depending on the number of connections +# and degree of rank subspace projection being performed. + +# %% + +# Time transforming data of each epoch iteratively +start = time.time() +for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): + spectral_connectivity_epochs( + np.expand_dims(epoch, axis=0), + method="coh", + indices=seed_target_indices(indices[0], indices[1]), + fmin=5, + fmax=35, + sfreq=epochs.info["sfreq"], + ) +duration = (time.time() - start) * 1000 + +# %% + +print(f"Time to compute connectivity: {duration:.0f} ms") + +print( + "\nTime to compute connectivity per epoch (window): ", + f"{duration/(N_EPOCHS//2):.0f} ms", +) + +######################################################################################## +# Case 2: Fitting to and transforming the same data +# ------------------------------------------------- +# +# As mentioned above, the decoding module classes can also be used to transform the same +# data the filters are fit to. This is a similar process to that of the +# ``spectral_connectivity_...()`` functions, but with the increased efficiency of +# fitting filters to a single frequency band as opposed to each frequency bin. 
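#
# This efficiency comes from pooling the cross-spectral density (CSD) over the band
# before fitting. Schematically (a sketch using MNE's CSD functions, with ``data``
# being an epochs array and the filter-fitting step a hypothetical helper)::
#
#     csd = csd_array_multitaper(data, sfreq, fmin=FMIN, fmax=FMAX)
#     csd_band = csd.sum(fmin=FMIN, fmax=FMAX)  # one CSD matrix for the whole band
#     filters = fit_filters(csd_band)  # a single filter problem, not one per bin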
#
# To demonstrate this approach, we will load some example MEG data and divide it into
# two-second-long epochs. We designate the left hemisphere sensors as the seeds and the
# right hemisphere sensors as the targets. Since this is sensor-space data, we will use
# the MIC method to analyse connectivity given its resilience to zero time-lag
# interactions (see :doc:`../compare_coherency_methods` for more information).

# %%

raw = mne.io.read_raw_ctf(data_path() / "SubjectCMC.ds")
raw.pick("mag")
raw.crop(50.0, 110.0).load_data()
raw.notch_filter(50)
raw.resample(100)

epochs = make_fixed_length_epochs(raw, duration=2.0).load_data()

# left hemisphere sensors
seeds = [idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] < 0]
# right hemisphere sensors
targets = [
    idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] > 0
]

########################################################################################
# There are two equivalent options for fitting and transforming the same data: 1)
# passing the data to the :meth:`~mne_connectivity.decoding.MIC.fit` and
# :meth:`~mne_connectivity.decoding.MIC.transform` methods sequentially; or 2) using the
# combined :meth:`~mne_connectivity.decoding.MIC.fit_transform` method (a sketch of the
# sequential form is shown below).
#
# We use the latter approach below, fitting the filters to the 15-20 Hz band and using
# the ``"imcoh"`` method in the call to the ``spectral_connectivity_...()`` functions.
# Plotting the results, we see a peak in connectivity at 15-20 Hz.

# %%

mic = MIC(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=(seeds, targets), rank=(3, 3))

start = time.time()
epochs_transformed = mic.fit_transform(epochs.get_data())

con_mic_class = spectral_connectivity_epochs(
    epochs_transformed,
    method="imcoh",
    indices=mic.get_transformed_indices(),
    fmin=5,
    fmax=30,
    sfreq=epochs.info["sfreq"],
)
class_duration = time.time() - start

ax = plt.subplot(111)
ax.plot(
    con_mic_class.freqs,
    np.abs(con_mic_class.get_data()[0]),
    color=plt.rcParams["axes.prop_cycle"].by_key()["color"][2],
    label="Decomposition class",
)
ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Connectivity (A.U.)")
ax.set_title("MIC")
plt.legend()
plt.show()

########################################################################################
# For comparison, we can also use the standard approach of the
# ``spectral_connectivity_...()`` functions, which shows a very similar connectivity
# profile in the 15-20 Hz frequency range. Bivariate coherence is again shown to
# demonstrate the signal-to-noise enhancements the multivariate approach offers.
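#
# Before that, to make the equivalence mentioned above concrete: the
# ``fit_transform()`` call used earlier is the same as running the two steps
# sequentially on the same data (a sketch)::
#
#     mic.fit(epochs.get_data())
#     epochs_transformed = mic.transform(epochs.get_data())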
+ +# %% + +start = time.time() +con_mic_func = spectral_connectivity_epochs( + epochs.get_data(), + method="mic", + indices=([seeds], [targets]), + fmin=5, + fmax=30, + sfreq=epochs.info["sfreq"], + rank=([3], [3]), +) +func_duration = time.time() - start + +con_imcoh = spectral_connectivity_epochs( + epochs.get_data(), + method="imcoh", + indices=seed_target_indices(seeds, targets), + fmin=5, + fmax=30, + sfreq=epochs.info["sfreq"], + rank=([3], [3]), +) + +ax = plt.subplot(111) +ax.plot( + con_mic_func.freqs, + np.abs(con_mic_func.get_data()[0]), + label="MIC (standard\nfunction)", +) +ax.plot( + con_imcoh.freqs, + np.mean(np.abs(con_imcoh.get_data()), axis=0), + label="ImCoh", +) +ax.plot( + con_mic_class.freqs, + np.abs(con_mic_class.get_data()[0]), + label="MIC (decomposition\nclass)", +) +ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band") +ax.set_xlabel("Frequency (Hz)") +ax.set_ylabel("Connectivity (A.U.)") +ax.set_title("MIC") +plt.legend() +plt.show() + +######################################################################################## +# As with the previous example, we can also compare the time taken to run the analyses. +# Here we see that the decomposition class is much faster than the +# ``spectral_connectivity_...()`` functions, thanks to the fact that the filters are fit +# to the entire frequency band and not each frequency bin. + +# %% + +print( + "Time to fit, transform, and compute connectivity (decomposition class): " + f"{class_duration:.2f} s" +) +print( + f"Time to fit, transform, and compute connectivity (standard function): " + f"{func_duration:.2f} s" +) + +######################################################################################## +# Limitations +# ----------- +# Finally, it is important to discuss a key limitation of the decoding module approach: +# the need to define a specific frequency band. Defining this band requires some +# existing knowledge about your data or the oscillatory activity you are studying. This +# insight may come from a pilot study where a frequency band of interest was identified, +# a canonical frequency band defined in the literature, etc... In contrast, by fitting +# filters to each frequency bin, the standard ``spectral_connectivity_...()`` functions +# are more flexible. +# +# Additionally, by applying filters fit on one set of data to another, you are assuming +# that the connectivity components the filters are designed to extract are consistent +# across the two sets of data. However, this may not be the case if you are applying the +# filters to data from a distinct functional state where the spatial distribution of the +# components differs. Again, by fitting filters to each new set of data passed in, the +# standard ``spectral_connectivity_...()`` functions are more flexible, extracting +# whatever connectivity components are present in that data. +# +# On these points, we note that the ``spectral_connectivity_...()`` functions complement +# the decoding module classes well, offering a tool by which to explore your data to: +# identify possible frequency bands of interest; and identify the spatial distributions +# of connectivity components to determine if they are consistent across different +# portions of the data. +# +# Ultimately, there are distinct advantages and disadvantages to both approaches, and +# one may be more suitable than the other depending on your use case. 
+ +# %% From a3bf25346902f56f730341bfff8cc997bdef5e2b Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 5 Jun 2024 14:24:32 +0200 Subject: [PATCH 14/38] Switch to cleaner epoch indexing --- examples/decoding/cohy_decomposition.py | 32 ++++++++++--------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index c3c57285..9a43a4e5 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -121,12 +121,11 @@ # Connectivity profile of first 30 epochs (filters fit to these epochs) con_cacoh_first = spectral_connectivity_epochs( - epochs.get_data(item=np.arange(N_EPOCHS // 2)), + epochs[: N_EPOCHS // 2], method="cacoh", indices=([indices[0]], [indices[1]]), fmin=5, fmax=35, - sfreq=epochs.info["sfreq"], rank=([3], [3]), ) ax = plt.subplot(111) @@ -134,12 +133,11 @@ # Connectivity profile of first 30 epochs (no filters) con_coh_first = spectral_connectivity_epochs( - epochs.get_data(item=np.arange(N_EPOCHS // 2)), + epochs[: N_EPOCHS // 2], method="coh", indices=seed_target_indices(indices[0], indices[1]), fmin=5, fmax=35, - sfreq=epochs.info["sfreq"], ) ax.plot(con_coh_first.freqs, np.mean(con_coh_first.get_data(), axis=0), label="Coh") ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band") @@ -174,12 +172,10 @@ # Fit filters to first 30 epochs cacoh = CaCoh(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=indices, rank=(3, 3)) -cacoh.fit(epochs.get_data(item=np.arange(N_EPOCHS // 2))) +cacoh.fit(epochs[: N_EPOCHS // 2].get_data()) # Use filters to transform data from last 30 epochs -epochs_transformed = cacoh.transform( - epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)) -) +epochs_transformed = cacoh.transform(epochs[N_EPOCHS // 2 :].get_data()) indices_transformed = cacoh.get_transformed_indices() ######################################################################################## @@ -203,12 +199,11 @@ # Connectivity profile of last 30 epochs (filters fit to these epochs) con_cacoh_last = spectral_connectivity_epochs( - epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)), + epochs[N_EPOCHS // 2 :], method="cacoh", indices=([indices[0]], [indices[1]]), fmin=5, fmax=35, - sfreq=epochs.info["sfreq"], rank=([3], [3]), ) ax = plt.subplot(111) @@ -220,12 +215,11 @@ # Connectivity profile of last 30 epochs (no filters) con_coh_last = spectral_connectivity_epochs( - epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)), + epochs[N_EPOCHS // 2 :], method="coh", indices=seed_target_indices(indices[0], indices[1]), fmin=5, fmax=35, - sfreq=epochs.info["sfreq"], ) ax.plot( con_coh_last.freqs, np.mean(np.abs(con_coh_last.get_data()), axis=0), label="Coh" @@ -269,12 +263,12 @@ # Time fitting of filters start_fit = time.time() -cacoh.fit(epochs.get_data(item=np.arange(N_EPOCHS // 2))) +cacoh.fit(epochs[: N_EPOCHS // 2].get_data()) fit_duration = (time.time() - start_fit) * 1000 # Time transforming data of each epoch iteratively start_transform = time.time() -for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): +for epoch in epochs[N_EPOCHS // 2 :]: epoch_transformed = cacoh.transform(epoch) spectral_connectivity_epochs( np.expand_dims(epoch_transformed, axis=0), @@ -306,7 +300,7 @@ # Time fitting and transforming data of each epoch iteratively start_fit_transform = time.time() -for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): +for epoch in epochs[N_EPOCHS // 2 :]: 
spectral_connectivity_epochs( np.expand_dims(epoch, axis=0), method="cacoh", @@ -338,7 +332,7 @@ # Time transforming data of each epoch iteratively start = time.time() -for epoch in epochs.get_data(item=np.arange(N_EPOCHS // 2, N_EPOCHS)): +for epoch in epochs[N_EPOCHS // 2 :]: spectral_connectivity_epochs( np.expand_dims(epoch, axis=0), method="coh", @@ -441,23 +435,21 @@ start = time.time() con_mic_func = spectral_connectivity_epochs( - epochs.get_data(), + epochs, method="mic", indices=([seeds], [targets]), fmin=5, fmax=30, - sfreq=epochs.info["sfreq"], rank=([3], [3]), ) func_duration = time.time() - start con_imcoh = spectral_connectivity_epochs( - epochs.get_data(), + epochs, method="imcoh", indices=seed_target_indices(seeds, targets), fmin=5, fmax=30, - sfreq=epochs.info["sfreq"], rank=([3], [3]), ) From ea48ce30e806b2728d6e93fd742a8d2bda48c6f0 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 5 Jun 2024 14:24:50 +0200 Subject: [PATCH 15/38] Fix spelling error --- examples/decoding/cohy_decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 9a43a4e5..8cf635d1 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -324,7 +324,7 @@ ) ######################################################################################## -# As a side note, it is important to consider that a multivariate approache may be as +# As a side note, it is important to consider that a multivariate approach may be as # fast or even faster than a bivariate approach, depending on the number of connections # and degree of rank subspace projection being performed. From 9bcff587283c7b5a1936ea7e4eba3ab6d127d5b5 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 5 Jun 2024 14:25:33 +0200 Subject: [PATCH 16/38] Update error checking --- mne_connectivity/decoding/decomposition.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index d6311c57..b10bcd2f 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -84,9 +84,9 @@ def __init__( if fmax > info["sfreq"] / 2: raise ValueError("`fmax` cannot be larger than the Nyquist frequency") - _validate_type(indices, tuple, "`indices`", "tuple of array-like") + _validate_type(indices, tuple, "`indices`", "tuple of array-likes") if len(indices) != 2: - raise ValueError("`indices` must be have length 2") + raise ValueError("`indices` must have length 2") for indices_group in indices: _validate_type( indices_group, @@ -119,7 +119,7 @@ def __init__( _validate_type(rank, (tuple, None), "`rank`", "tuple of ints or None") if rank is not None: if len(rank) != 2: - raise ValueError("`rank` must be have length 2") + raise ValueError("`rank` must have length 2") for rank_group in rank: _validate_type(rank_group, int, "`rank`", "tuple of ints or None") _rank = self._check_rank(rank, indices) @@ -155,7 +155,7 @@ def _check_indices(self, indices, n_chans): max_idx = np.max(indices.compressed()) if max_idx + 1 > n_chans: raise ValueError( - "At least one entry in `indices` is greater than the number of " + "at least one entry in `indices` is greater than the number of " "channels in `info`" ) @@ -167,10 +167,14 @@ def _check_rank(self, rank, indices): # convert to multivariate format rank = ([rank[0]], [rank[1]]) + # make sure ranks are > 0 
+ if np.any(np.array(rank) <= 0): + raise ValueError("entries of `rank` must be > 0") + # find whether entries of rank exceed number of channels in indices if rank[0][0] > len(indices[0]) or rank[1][0] > len(indices[1]): raise ValueError( - "At least one entry in `rank` is greater than the number of " + "at least one entry in `rank` is greater than the number of " "seed/target channels in `indices`" ) From 6c8ae9732eb17ab14a9af58f8a31650db77c92ce Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 11:53:45 +0200 Subject: [PATCH 17/38] Bug fix indices setter wrong format --- mne_connectivity/decoding/decomposition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index b10bcd2f..1f1eacfe 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -41,7 +41,9 @@ def indices(self): @indices.setter def indices(self, indices): """Set ``indices`` parameter using the input format.""" - self._indices = (np.array([indices[0]]), np.array([indices[1]])) + self._indices = _check_multivariate_indices( + ([indices[0]], [indices[1]]), self.info["nchan"] + ) @property def rank(self): From 55b14ab9fe9f4ceae9c2a4c857fc07fe2ca18339 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 11:57:49 +0200 Subject: [PATCH 18/38] Add unit tests --- .../decoding/tests/test_decomposition.py | 453 ++++++++++++++++++ 1 file changed, 453 insertions(+) create mode 100644 mne_connectivity/decoding/tests/test_decomposition.py diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py new file mode 100644 index 00000000..80026826 --- /dev/null +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -0,0 +1,453 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from mne_connectivity import ( + make_signals_in_freq_bands, + seed_target_indices, + spectral_connectivity_epochs, +) +from mne_connectivity.decoding import MIC, CaCoh +from mne_connectivity.utils import _check_multivariate_indices + + +@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_spectral_decomposition(DecompClass, mode): + """Test spectral decomposition classes run and give expected results.""" + # SIMULATE DATA + # Settings + n_seeds = 3 + n_targets = 3 + n_signals = n_seeds + n_targets + n_epochs = 60 + trans_bandwidth = 1 + + # Get data with connectivity to optimise (~90° angle good for MIC & CaCoh) + fmin_optimise = 11 + fmax_optimise = 14 + epochs_optimise = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(fmin_optimise, fmax_optimise), + n_epochs=n_epochs, + trans_bandwidth=trans_bandwidth, + snr=0.5, + connection_delay=10, # ~90° interaction angle for this freq. band + rng_seed=44, + ) + + # Get data with connectivity to ignore + fmin_ignore = 21 + fmax_ignore = 24 + epochs_ignore = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(fmin_ignore, fmax_ignore), + n_epochs=n_epochs, + trans_bandwidth=trans_bandwidth, + snr=0.5, + connection_delay=6, # ~90° interaction angle for this freq. 
band + rng_seed=42, + ) + + # Combine data and get indices + epochs = epochs_optimise.add_channels([epochs_ignore]) + seeds = np.concatenate((np.arange(n_seeds), np.arange(n_seeds) + n_signals)) + targets = np.concatenate( + (np.arange(n_targets) + n_seeds, np.arange(n_targets) + n_signals + n_seeds) + ) + indices = (seeds, targets) + + bivariate_method = "coh" if DecompClass == CaCoh else "imcoh" + multivariate_method = "cacoh" if DecompClass == CaCoh else "mic" + + cwt_freq_resolution = 0.5 + cwt_freqs = np.arange(5, 30, cwt_freq_resolution) + cwt_n_cycles = 6 + + # TEST FITTING AND TRANSFORMING SAME DATA EXTRACTS CONNECTIVITY + decomp_class = DecompClass( + info=epochs.info, + fmin=fmin_optimise, + fmax=fmax_optimise, + indices=indices, + mode=mode, + cwt_freq_resolution=cwt_freq_resolution, + cwt_n_cycles=cwt_n_cycles, + ) + epochs_transformed = decomp_class.fit_transform( + X=epochs[: n_epochs // 2].get_data() + ) + con_mv_class = spectral_connectivity_epochs( + epochs_transformed, + method=bivariate_method, + indices=decomp_class.get_transformed_indices(), + sfreq=epochs.info["sfreq"], + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) + con_mv_func = spectral_connectivity_epochs( + epochs[: n_epochs // 2], + method=multivariate_method, + indices=([seeds], [targets]), + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) + con_bv_func = spectral_connectivity_epochs( + epochs[: n_epochs // 2], + method=bivariate_method, + indices=seed_target_indices(seeds, targets), + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) + + # Frequencies of interest + freqs = np.array(con_mv_class.freqs) + freqs_optimise = (freqs >= fmin_optimise) & (freqs <= fmax_optimise) + freqs_ignore = (freqs >= fmin_ignore) & (freqs <= fmax_ignore) + + # Thresholds for checking validity of connectivity (work across all modes) + optimisation_diff = 0.35 # optimisation causes big increase in connectivity + similarity_thresh = 0.15 # freqs. being optimised or ignored should be very similar + + # Test selective optimisation of desired freq. band vs. no optimisation + assert ( + np.abs(con_mv_class.get_data()[0, freqs_optimise]).mean() + > np.abs(con_bv_func.get_data()[:, freqs_optimise]).mean() + optimisation_diff + ) # check connectivity for optimised freq. band higher than without optimisation + assert_allclose( + np.abs(con_mv_class.get_data()[0, freqs_ignore]).mean(), + np.abs(con_bv_func.get_data()[:, freqs_ignore]).mean(), + atol=similarity_thresh, + ) # check connectivity for ignored freq. band similar to no optimisation + + # Test band-wise optimisation similar to bin-wise optimisation + assert_allclose( + np.abs(con_mv_class.get_data()[0, freqs_optimise]).mean(), + np.abs(con_mv_func.get_data()[0, freqs_optimise]).mean(), + atol=similarity_thresh, + ) # check connectivity for optimised freq. band similar for both versions + assert ( + np.abs(con_mv_class.get_data()[0, freqs_ignore]).mean() + < np.abs(con_mv_func.get_data()[0, freqs_ignore]).mean() - optimisation_diff + ) # check connectivity for ignored freq. 
band lower than with optimisation + + # Test `fit_transform` equivalent to `fit` and `transform` separately + if mode == "multitaper": # only need to test once + decomp_class_2 = DecompClass( + info=epochs.info, + fmin=fmin_optimise, + fmax=fmax_optimise, + indices=indices, + mode=mode, + ) + decomp_class_2.fit(X=epochs[: n_epochs // 2].get_data()) + epochs_transformed_2 = decomp_class_2.transform( + X=epochs[: n_epochs // 2].get_data() + ) + assert_allclose(epochs_transformed, epochs_transformed_2) + assert_allclose(decomp_class.filters_, decomp_class_2.filters_) + assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_) + + # TEST FITTING ON ONE PIECE OF DATA AND TRANSFORMING ANOTHER + con_mv_class_unseen_data = spectral_connectivity_epochs( + decomp_class.transform(X=epochs[n_epochs // 2 :].get_data()), + method=bivariate_method, + indices=decomp_class.get_transformed_indices(), + sfreq=epochs.info["sfreq"], + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) + assert_allclose( + np.abs(con_mv_class.get_data()[0, freqs_optimise]).mean(), + np.abs(con_mv_class_unseen_data.get_data()[0, freqs_optimise]).mean(), + atol=similarity_thresh, + ) # check connectivity for optimised freq. band similarly high for seen & unseen + assert_allclose( + np.abs(con_mv_class.get_data()[0, freqs_ignore]).mean(), + np.abs(con_mv_class_unseen_data.get_data()[0, freqs_ignore]).mean(), + atol=similarity_thresh, + ) # check connectivity for optimised freq. band similarly low for seen & unseen + + # TEST GETTERS & SETTERS + # Test indices internal storage and returned format + if mode == "multitaper": # only need to test once + assert np.all(np.array(decomp_class.indices) == np.array((seeds, targets))) + assert np.all( + decomp_class._indices + == _check_multivariate_indices(([seeds], [targets]), n_signals) + ) + decomp_class.set_params(indices=(targets, seeds)) + assert np.all(np.array(decomp_class.indices) == np.array((targets, seeds))) + assert np.all( + decomp_class._indices + == _check_multivariate_indices(([targets], [seeds]), n_signals) + ) + + # Test rank internal storage and returned format + assert np.all(decomp_class.rank == (n_signals, n_signals)) + assert np.all(decomp_class._rank == ([n_signals], [n_signals])) + decomp_class.set_params(rank=(1, 2)) + assert np.all(decomp_class.rank == (1, 2)) + assert np.all(decomp_class._rank == ([1], [2])) + + +@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_spectral_decomposition_parallel(DecompClass, mode): + """Test spectral decomposition classes run with parallelisation.""" + # SIMULATE DATA + n_seeds = 3 + n_targets = 3 + fmin = 10 + fmax = 15 + epochs = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=(fmin, fmax), + snr=0.5, + rng_seed=44, + ) + + # RUN DECOMPOSITION + decomp_class = DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=(np.arange(n_seeds), np.arange(n_targets) + n_seeds), + mode=mode, + cwt_freq_resolution=1, + cwt_n_cycles=6, + n_jobs=2, # use parallelisation + ) + decomp_class.fit_transform(X=epochs.get_data()) + + +@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +def test_spectral_decomposition_error_catch(DecompClass): + """Test error catching for spectral decomposition classes.""" + # SIMULATE DATA + n_seeds = 3 + n_targets = 3 + fmin = 15 + fmax = 20 + epochs = make_signals_in_freq_bands( + n_seeds=n_seeds, n_targets=n_targets, freq_band=(fmin, fmax), rng_seed=44 + 
) + indices = (np.arange(n_seeds), np.arange(n_targets) + n_seeds) + + # TEST BAD INITIALISATION + # Test info + with pytest.raises(TypeError, match="`info` must be an instance of mne.Info"): + DecompClass(info="info", fmin=fmin, fmax=fmax, indices=indices) + + # Test fmin & fmax + with pytest.raises(TypeError, match="`fmin` must be an instance of int or float"): + DecompClass(info=epochs.info, fmin="15", fmax=fmax, indices=indices) + with pytest.raises(TypeError, match="`fmax` must be an instance of int or float"): + DecompClass(info=epochs.info, fmin=fmin, fmax="20", indices=indices) + with pytest.raises(ValueError, match="`fmax` must be larger than `fmin`"): + DecompClass(info=epochs.info, fmin=fmax, fmax=fmin, indices=indices) + with pytest.raises( + ValueError, match="`fmax` cannot be larger than the Nyquist frequency" + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=epochs.info["sfreq"] / 2 + 1, + indices=indices, + ) + + # Test indices + with pytest.raises( + TypeError, match="`indices` must be an instance of tuple of array-likes" + ): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=list(indices)) + with pytest.raises( + TypeError, match="`indices` must be an instance of tuple of array-likes" + ): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=(0, 1)) + with pytest.raises(ValueError, match="`indices` must have length 2"): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=(indices[0],)) + with pytest.raises( + ValueError, + match=( + "multivariate indices cannot contain repeated channels within a seed or " + "target" + ), + ): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=([0, 0], [1, 2])) + with pytest.raises( + ValueError, + match=( + "multivariate indices cannot contain repeated channels within a seed or " + "target" + ), + ): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=([0, 1], [2, 2])) + with pytest.raises( + ValueError, match="a negative channel index is not present in the data" + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=([0], [(n_seeds + n_targets) * -1]), + ) + with pytest.raises( + ValueError, + match=( + "at least one entry in `indices` is greater than the number of channels in " + "`info`" + ), + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=([0], [n_seeds + n_targets]), + ) + + # Test mode + with pytest.raises(ValueError, match="Invalid value for the 'mode' parameter"): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mode="notamode" + ) + + # Test multitaper settings + with pytest.raises( + TypeError, match="`mt_bandwidth` must be an instance of int, float, or None" + ): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_bandwidth="5" + ) + with pytest.raises(TypeError, match="`mt_adaptive` must be an instance of bool"): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_adaptive=1 + ) + with pytest.raises(TypeError, match="`mt_low_bias` must be an instance of bool"): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_low_bias=1 + ) + + # Test wavelet settings + with pytest.raises( + TypeError, match="`cwt_freq_resolution` must be an instance of int or float" + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=indices, + cwt_freq_resolution="1", + ) + with pytest.raises( + TypeError, + match=( + "`cwt_n_cycles` must be an instance of int, float, or array-like of ints " + "or floats" + ), + ): + 
DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, cwt_n_cycles="5" + ) + + # Test n_components + with pytest.raises( + TypeError, match="`n_components` must be an instance of int or None" + ): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, n_components="2" + ) + + # Test rank + with pytest.raises( + TypeError, match="`rank` must be an instance of tuple of ints or None" + ): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank="2") + with pytest.raises( + TypeError, match="`rank` must be an instance of tuple of ints or None" + ): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=("2", "2") + ) + with pytest.raises(ValueError, match="`rank` must have length 2"): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=(2,)) + with pytest.raises(ValueError, match="entries of `rank` must be > 0"): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=(0, 1) + ) + with pytest.raises( + ValueError, + match=( + "at least one entry in `rank` is greater than the number of seed/target " + "channels in `indices`" + ), + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=indices, + rank=(n_seeds + 1, n_targets), + ) + with pytest.raises( + ValueError, + match=( + "at least one entry in `rank` is greater than the number of seed/target " + "channels in `indices`" + ), + ): + DecompClass( + info=epochs.info, + fmin=fmin, + fmax=fmax, + indices=indices, + rank=(n_seeds, n_targets + 1), + ) + + # Test n_jobs + with pytest.raises(TypeError, match="`n_jobs` must be an instance of int"): + DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, n_jobs="1") + + # Test verbose + with pytest.raises( + TypeError, match="`verbose` must be an instance of bool, str, int, or None" + ): + DecompClass( + info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, verbose=[True] + ) + + decomp_class = DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices) + + # TEST BAD FITTING + # Test input data + with pytest.raises(TypeError, match="`X` must be an instance of NumPy array"): + decomp_class.fit(X=epochs.get_data().tolist()) + with pytest.raises(ValueError, match="Invalid value for the '`X.ndim`' parameter"): + decomp_class.fit(X=epochs.get_data()[0]) + with pytest.raises(ValueError, match="`X` does not match Info"): + decomp_class.fit(X=epochs.get_data()[:, :-1]) + # XXX: Add test for rank of X being <= n_components when n_components can be > 1 + + # TEST TRANSFORM BEFORE FITTING + with pytest.raises( + RuntimeError, + match="no filters are available, please call the `fit` method first", + ): + decomp_class.transform(X=epochs.get_data()) + + decomp_class.fit(X=epochs.get_data()) + + # TEST BAD TRANSFORMING + with pytest.raises(TypeError, match="`X` must be an instance of NumPy array"): + decomp_class.transform(X=epochs.get_data().tolist()) + with pytest.raises(ValueError, match="Invalid value for the '`X.ndim`' parameter"): + decomp_class.transform(X=epochs.get_data()[0, 0]) + with pytest.raises(ValueError, match="`X` does not match Info"): + decomp_class.transform(X=epochs.get_data()[:, :-1]) From 3861da5f668c5f885da65db78e2e77437e2aed8c Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 12:30:06 +0200 Subject: [PATCH 19/38] Update example from review --- examples/decoding/cohy_decomposition.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git 
a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 8cf635d1..c8f3bd3b 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -295,6 +295,11 @@ # In contrast, here we follow the same sequential window approach, but fit filters to # each window separately rather than using a pre-computed set. Naturally, the process of # fitting and transforming the data for each window is considerably slower. +# +# Furthermore, given the noisy nature of single windows of data, there is a risk of +# overfitting the filters to this noise as opposed to the genuine interaction(s) of +# interest. This risk is mitigated by performing the initial filter fitting on a larger +# set of data. # %% @@ -416,20 +421,20 @@ con_mic_class.freqs, np.abs(con_mic_class.get_data()[0]), color=plt.rcParams["axes.prop_cycle"].by_key()["color"][2], - label="Decomposition class", + label="MIC (decomposition\nclass)", ) ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band") ax.set_xlabel("Frequency (Hz)") ax.set_ylabel("Connectivity (A.U.)") -ax.set_title("MIC") plt.legend() plt.show() ######################################################################################## # For comparison, we can also use the standard approach of the # ``spectral_connectivity_...()`` functions, which shows a very similar connectivity -# profile in the 15-20 Hz frequency range. Bivariate coherence is again shown to -# demonstrate the signal-to-noise enhancements the multivariate approach offers. +# profile in the 15-20 Hz frequency range (but not identical due to band- vs. bin-wise +# filter fitting approaches). Bivariate coherence is again shown to demonstrate the +# signal-to-noise enhancements the multivariate approach offers. # %% @@ -472,7 +477,6 @@ ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. 
band") ax.set_xlabel("Frequency (Hz)") ax.set_ylabel("Connectivity (A.U.)") -ax.set_title("MIC") plt.legend() plt.show() From 025e6c17cf1cbf00a54b5cc7054ee48943cbcaa1 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 15:12:00 +0200 Subject: [PATCH 20/38] Update cwt_morlet params --- examples/decoding/cohy_decomposition.py | 27 +- mne_connectivity/decoding/decomposition.py | 110 +++-- .../decoding/tests/test_decomposition.py | 407 ++++++++++++------ mne_connectivity/utils/docs.py | 18 +- 4 files changed, 379 insertions(+), 183 deletions(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index c8f3bd3b..79923eb2 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -171,7 +171,14 @@ # %% # Fit filters to first 30 epochs -cacoh = CaCoh(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=indices, rank=(3, 3)) +cacoh = CaCoh( + info=epochs.info, + indices=indices, + mode="multitaper", + fmin=FMIN, + fmax=FMAX, + rank=(3, 3), +) cacoh.fit(epochs[: N_EPOCHS // 2].get_data()) # Use filters to transform data from last 30 epochs @@ -259,7 +266,14 @@ # %% -cacoh = CaCoh(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=indices, rank=(3, 3)) +cacoh = CaCoh( + info=epochs.info, + indices=indices, + mode="multitaper", + fmin=FMIN, + fmax=FMAX, + rank=(3, 3), +) # Time fitting of filters start_fit = time.time() @@ -401,7 +415,14 @@ # %% -mic = MIC(info=epochs.info, fmin=FMIN, fmax=FMAX, indices=(seeds, targets), rank=(3, 3)) +mic = MIC( + info=epochs.info, + indices=(seeds, targets), + mode="multitaper", + fmin=FMIN, + fmax=FMAX, + rank=(3, 3), +) start = time.time() epochs_transformed = mic.fit_transform(epochs.get_data()) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 1f1eacfe..2062d6d5 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -61,14 +61,14 @@ def rank(self, rank): def __init__( self, info, - fmin, - fmax, indices, mode="multitaper", + fmin=None, + fmax=None, mt_bandwidth=None, mt_adaptive=False, mt_low_bias=True, - cwt_freq_resolution=1, + cwt_freqs=None, cwt_n_cycles=7, n_components=None, rank=None, @@ -79,13 +79,6 @@ def __init__( # Validate inputs _validate_type(info, Info, "`info`", "mne.Info") - _validate_type(fmin, (int, float), "`fmin`", "int or float") - _validate_type(fmax, (int, float), "`fmax`", "int or float") - if fmin > fmax: - raise ValueError("`fmax` must be larger than `fmin`") - if fmax > info["sfreq"] / 2: - raise ValueError("`fmax` cannot be larger than the Nyquist frequency") - _validate_type(indices, tuple, "`indices`", "tuple of array-likes") if len(indices) != 2: raise ValueError("`indices` must have length 2") @@ -99,22 +92,52 @@ def __init__( _indices = self._check_indices(indices, info["nchan"]) _check_option("mode", mode, ("multitaper", "fourier", "cwt_morlet")) - - _validate_type( - mt_bandwidth, (int, float, None), "`mt_bandwidth`", "int, float, or None" - ) - _validate_type(mt_adaptive, bool, "`mt_adaptive`", "bool") - _validate_type(mt_low_bias, bool, "`mt_low_bias`", "bool") - - _validate_type( - cwt_freq_resolution, (int, float), "`cwt_freq_resolution`", "int or float" - ) - _validate_type( - cwt_n_cycles, - (int, float, tuple, list, np.ndarray), - "`cwt_n_cycles`", - "int, float, or array-like of ints or floats", - ) + if mode in ["multitaper", "fourier"]: + if fmin is None or fmax is None: + raise TypeError( + "`fmin` and `fmax` 
must not be None if `mode` is 'multitaper' or " + "'fourier'" + ) + _validate_type(fmin, (int, float), "`fmin`", "int or float") + _validate_type(fmax, (int, float), "`fmax`", "int or float") + if fmin > fmax: + raise ValueError("`fmax` must be larger than `fmin`") + if fmax > info["sfreq"] / 2: + raise ValueError("`fmax` cannot be larger than the Nyquist frequency") + if mode == "multitaper": + _validate_type( + mt_bandwidth, + (int, float, None), + "`mt_bandwidth`", + "int, float, or None", + ) + _validate_type(mt_adaptive, bool, "`mt_adaptive`", "bool") + _validate_type(mt_low_bias, bool, "`mt_low_bias`", "bool") + else: + if cwt_freqs is None: + raise TypeError( + "`cwt_freqs` must not be None if `mode` is 'cwt_morlet'" + ) + _validate_type( + cwt_freqs, (tuple, list, np.ndarray), "`cwt_freqs`", "array-like" + ) + if cwt_freqs[-1] > info["sfreq"] / 2: + raise ValueError( + "last entry of `cwt_freqs` cannot be larger than the Nyquist " + "frequency" + ) + _validate_type( + cwt_n_cycles, + (int, float, tuple, list, np.ndarray), + "`cwt_n_cycles`", + "int, float, or array-like", + ) + if isinstance(cwt_n_cycles, (tuple, list, np.ndarray)) and len( + cwt_n_cycles + ) != len(cwt_freqs): + raise ValueError( + "`cwt_n_cycles` array-like must have the same length as `cwt_freqs`" + ) _validate_type(n_components, (int, None), "`n_components`", "int or None") @@ -134,14 +157,14 @@ def __init__( # Store inputs self.info = info - self.fmin = fmin - self.fmax = fmax self._indices = _indices # uses getter/setter for public parameter self.mode = mode + self.fmin = fmin + self.fmax = fmax self.mt_bandwidth = mt_bandwidth self.mt_adaptive = mt_adaptive self.mt_low_bias = mt_low_bias - self.cwt_freq_resolution = cwt_freq_resolution + self.cwt_freqs = cwt_freqs self.cwt_n_cycles = cwt_n_cycles self.n_components = 1 # XXX: fixed until n_comps > 1 supported self._rank = _rank # uses getter/setter for public parameter @@ -275,18 +298,17 @@ def _compute_csd(self, X): csd = csd_array_fourier(**csd_kwargs) else: csd_kwargs.update( - { - "frequencies": np.arange( - self.fmin, - self.fmax + self.cwt_freq_resolution, - self.cwt_freq_resolution, - ), - "n_cycles": self.cwt_n_cycles, - } + {"frequencies": self.cwt_freqs, "n_cycles": self.cwt_n_cycles} ) csd = csd_array_morlet(**csd_kwargs) - csd = csd.sum(self.fmin, self.fmax).get_data(index=0) + if self.mode in ["multitaper", "fourier"]: + fmin = self.fmin + fmax = self.fmax + else: + fmin = self.cwt_freqs[0] + fmax = self.cwt_freqs[-1] + csd = csd.sum(fmin, fmax).get_data(index=0) csd = np.reshape(csd, csd.shape[0] ** 2) return np.expand_dims(csd, 1) @@ -409,14 +431,14 @@ class CaCoh(_AbstractDecompositionBase): Parameters ---------- %(info_decoding)s - %(fmin_decoding)s - %(fmax_decoding)s %(indices_decoding)s %(mode)s + %(fmin_decoding)s + %(fmax_decoding)s %(mt_bandwidth)s %(mt_adaptive)s %(mt_low_bias)s - %(cwt_freq_resolution)s + %(cwt_freqs)s %(cwt_n_cycles)s %(n_components)s %(rank)s @@ -460,14 +482,14 @@ class MIC(_AbstractDecompositionBase): Parameters ---------- %(info_decoding)s - %(fmin_decoding)s - %(fmax_decoding)s %(indices_decoding)s %(mode)s + %(fmin_decoding)s + %(fmax_decoding)s %(mt_bandwidth)s %(mt_adaptive)s %(mt_low_bias)s - %(cwt_freq_resolution)s + %(cwt_freqs)s %(cwt_n_cycles)s %(n_components)s %(rank)s diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index 80026826..21742f0c 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ 
b/mne_connectivity/decoding/tests/test_decomposition.py @@ -22,6 +22,8 @@ def test_spectral_decomposition(DecompClass, mode): n_signals = n_seeds + n_targets n_epochs = 60 trans_bandwidth = 1 + fstart = 5 # start computing connectivity + fend = 30 # stop computing connectivity # Get data with connectivity to optimise (~90° angle good for MIC & CaCoh) fmin_optimise = 11 @@ -62,18 +64,18 @@ def test_spectral_decomposition(DecompClass, mode): bivariate_method = "coh" if DecompClass == CaCoh else "imcoh" multivariate_method = "cacoh" if DecompClass == CaCoh else "mic" - cwt_freq_resolution = 0.5 - cwt_freqs = np.arange(5, 30, cwt_freq_resolution) + cwt_freq_res = 0.5 + cwt_freqs = np.arange(fmin_optimise, fmax_optimise + cwt_freq_res, cwt_freq_res) cwt_n_cycles = 6 # TEST FITTING AND TRANSFORMING SAME DATA EXTRACTS CONNECTIVITY decomp_class = DecompClass( info=epochs.info, - fmin=fmin_optimise, - fmax=fmax_optimise, indices=indices, mode=mode, - cwt_freq_resolution=cwt_freq_resolution, + fmin=fmin_optimise, + fmax=fmax_optimise, + cwt_freqs=cwt_freqs, cwt_n_cycles=cwt_n_cycles, ) epochs_transformed = decomp_class.fit_transform( @@ -85,7 +87,9 @@ def test_spectral_decomposition(DecompClass, mode): indices=decomp_class.get_transformed_indices(), sfreq=epochs.info["sfreq"], mode=mode, - cwt_freqs=cwt_freqs, + fmin=fstart, + fmax=fend, + cwt_freqs=np.arange(fstart, fend + cwt_freq_res, cwt_freq_res), cwt_n_cycles=cwt_n_cycles, ) con_mv_func = spectral_connectivity_epochs( @@ -93,7 +97,9 @@ def test_spectral_decomposition(DecompClass, mode): method=multivariate_method, indices=([seeds], [targets]), mode=mode, - cwt_freqs=cwt_freqs, + fmin=fstart, + fmax=fend, + cwt_freqs=np.arange(fstart, fend + cwt_freq_res, cwt_freq_res), cwt_n_cycles=cwt_n_cycles, ) con_bv_func = spectral_connectivity_epochs( @@ -101,7 +107,9 @@ def test_spectral_decomposition(DecompClass, mode): method=bivariate_method, indices=seed_target_indices(seeds, targets), mode=mode, - cwt_freqs=cwt_freqs, + fmin=fstart, + fmax=fend, + cwt_freqs=np.arange(fstart, fend + cwt_freq_res, cwt_freq_res), cwt_n_cycles=cwt_n_cycles, ) @@ -137,21 +145,22 @@ def test_spectral_decomposition(DecompClass, mode): ) # check connectivity for ignored freq. 
band lower than with optimisation # Test `fit_transform` equivalent to `fit` and `transform` separately - if mode == "multitaper": # only need to test once - decomp_class_2 = DecompClass( - info=epochs.info, - fmin=fmin_optimise, - fmax=fmax_optimise, - indices=indices, - mode=mode, - ) - decomp_class_2.fit(X=epochs[: n_epochs // 2].get_data()) - epochs_transformed_2 = decomp_class_2.transform( - X=epochs[: n_epochs // 2].get_data() - ) - assert_allclose(epochs_transformed, epochs_transformed_2) - assert_allclose(decomp_class.filters_, decomp_class_2.filters_) - assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_) + decomp_class_2 = DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin_optimise, + fmax=fmax_optimise, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) + decomp_class_2.fit(X=epochs[: n_epochs // 2].get_data()) + epochs_transformed_2 = decomp_class_2.transform( + X=epochs[: n_epochs // 2].get_data() + ) + assert_allclose(epochs_transformed, epochs_transformed_2) + assert_allclose(decomp_class.filters_, decomp_class_2.filters_) + assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_) # TEST FITTING ON ONE PIECE OF DATA AND TRANSFORMING ANOTHER con_mv_class_unseen_data = spectral_connectivity_epochs( @@ -160,7 +169,9 @@ def test_spectral_decomposition(DecompClass, mode): indices=decomp_class.get_transformed_indices(), sfreq=epochs.info["sfreq"], mode=mode, - cwt_freqs=cwt_freqs, + fmin=fstart, + fmax=fend, + cwt_freqs=np.arange(fstart, fend + cwt_freq_res, cwt_freq_res), cwt_n_cycles=cwt_n_cycles, ) assert_allclose( @@ -176,25 +187,24 @@ def test_spectral_decomposition(DecompClass, mode): # TEST GETTERS & SETTERS # Test indices internal storage and returned format - if mode == "multitaper": # only need to test once - assert np.all(np.array(decomp_class.indices) == np.array((seeds, targets))) - assert np.all( - decomp_class._indices - == _check_multivariate_indices(([seeds], [targets]), n_signals) - ) - decomp_class.set_params(indices=(targets, seeds)) - assert np.all(np.array(decomp_class.indices) == np.array((targets, seeds))) - assert np.all( - decomp_class._indices - == _check_multivariate_indices(([targets], [seeds]), n_signals) - ) + assert np.all(np.array(decomp_class.indices) == np.array((seeds, targets))) + assert np.all( + decomp_class._indices + == _check_multivariate_indices(([seeds], [targets]), n_signals) + ) + decomp_class.set_params(indices=(targets, seeds)) + assert np.all(np.array(decomp_class.indices) == np.array((targets, seeds))) + assert np.all( + decomp_class._indices + == _check_multivariate_indices(([targets], [seeds]), n_signals) + ) - # Test rank internal storage and returned format - assert np.all(decomp_class.rank == (n_signals, n_signals)) - assert np.all(decomp_class._rank == ([n_signals], [n_signals])) - decomp_class.set_params(rank=(1, 2)) - assert np.all(decomp_class.rank == (1, 2)) - assert np.all(decomp_class._rank == ([1], [2])) + # Test rank internal storage and returned format + assert np.all(decomp_class.rank == (n_signals, n_signals)) + assert np.all(decomp_class._rank == ([n_signals], [n_signals])) + decomp_class.set_params(rank=(1, 2)) + assert np.all(decomp_class.rank == (1, 2)) + assert np.all(decomp_class._rank == ([1], [2])) @pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) @@ -217,11 +227,11 @@ def test_spectral_decomposition_parallel(DecompClass, mode): # RUN DECOMPOSITION decomp_class = DecompClass( info=epochs.info, - fmin=fmin, - fmax=fmax, indices=(np.arange(n_seeds), 
np.arange(n_targets) + n_seeds), mode=mode, - cwt_freq_resolution=1, + fmin=fmin, + fmax=fmax, + cwt_freqs=np.arange(fmin, fmax + 0.5, 0.5), cwt_n_cycles=6, n_jobs=2, # use parallelisation ) @@ -229,7 +239,8 @@ def test_spectral_decomposition_parallel(DecompClass, mode): @pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) -def test_spectral_decomposition_error_catch(DecompClass): +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) +def test_spectral_decomposition_error_catch(DecompClass, mode): """Test error catching for spectral decomposition classes.""" # SIMULATE DATA n_seeds = 3 @@ -240,40 +251,25 @@ def test_spectral_decomposition_error_catch(DecompClass): n_seeds=n_seeds, n_targets=n_targets, freq_band=(fmin, fmax), rng_seed=44 ) indices = (np.arange(n_seeds), np.arange(n_targets) + n_seeds) + cwt_freqs = np.arange(fmin, fmax + 0.5, 0.5) + cwt_n_cycles = 6 # TEST BAD INITIALISATION # Test info with pytest.raises(TypeError, match="`info` must be an instance of mne.Info"): - DecompClass(info="info", fmin=fmin, fmax=fmax, indices=indices) - - # Test fmin & fmax - with pytest.raises(TypeError, match="`fmin` must be an instance of int or float"): - DecompClass(info=epochs.info, fmin="15", fmax=fmax, indices=indices) - with pytest.raises(TypeError, match="`fmax` must be an instance of int or float"): - DecompClass(info=epochs.info, fmin=fmin, fmax="20", indices=indices) - with pytest.raises(ValueError, match="`fmax` must be larger than `fmin`"): - DecompClass(info=epochs.info, fmin=fmax, fmax=fmin, indices=indices) - with pytest.raises( - ValueError, match="`fmax` cannot be larger than the Nyquist frequency" - ): - DecompClass( - info=epochs.info, - fmin=fmin, - fmax=epochs.info["sfreq"] / 2 + 1, - indices=indices, - ) + DecompClass(info="info", indices=indices) # Test indices with pytest.raises( TypeError, match="`indices` must be an instance of tuple of array-likes" ): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=list(indices)) + DecompClass(info=epochs.info, indices=list(indices)) with pytest.raises( TypeError, match="`indices` must be an instance of tuple of array-likes" ): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=(0, 1)) + DecompClass(info=epochs.info, indices=(0, 1)) with pytest.raises(ValueError, match="`indices` must have length 2"): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=(indices[0],)) + DecompClass(info=epochs.info, indices=(indices[0],)) with pytest.raises( ValueError, match=( @@ -281,7 +277,7 @@ def test_spectral_decomposition_error_catch(DecompClass): "target" ), ): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=([0, 0], [1, 2])) + DecompClass(info=epochs.info, indices=([0, 0], [1, 2])) with pytest.raises( ValueError, match=( @@ -289,16 +285,11 @@ def test_spectral_decomposition_error_catch(DecompClass): "target" ), ): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=([0, 1], [2, 2])) + DecompClass(info=epochs.info, indices=([0, 1], [2, 2])) with pytest.raises( ValueError, match="a negative channel index is not present in the data" ): - DecompClass( - info=epochs.info, - fmin=fmin, - fmax=fmax, - indices=([0], [(n_seeds + n_targets) * -1]), - ) + DecompClass(info=epochs.info, indices=([0], [(n_seeds + n_targets) * -1])) with pytest.raises( ValueError, match=( @@ -306,81 +297,211 @@ def test_spectral_decomposition_error_catch(DecompClass): "`info`" ), ): - DecompClass( - info=epochs.info, - fmin=fmin, - fmax=fmax, - indices=([0], [n_seeds + n_targets]), - ) + 
DecompClass(info=epochs.info, indices=([0], [n_seeds + n_targets])) # Test mode with pytest.raises(ValueError, match="Invalid value for the 'mode' parameter"): - DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mode="notamode" - ) + DecompClass(info=epochs.info, indices=indices, mode="notamode") + + # Test fmin & fmax + if mode in ["multitaper", "fourier"]: + with pytest.raises( + TypeError, + match=( + "`fmin` and `fmax` must not be None if `mode` is 'multitaper' or " + "'fourier'" + ), + ): + DecompClass( + info=epochs.info, indices=indices, mode=mode, fmin=None, fmax=fmax + ) + with pytest.raises( + TypeError, + match=( + "`fmin` and `fmax` must not be None if `mode` is 'multitaper' or " + "'fourier'" + ), + ): + DecompClass( + info=epochs.info, indices=indices, mode=mode, fmin=fmin, fmax=None + ) + with pytest.raises( + TypeError, match="`fmin` must be an instance of int or float" + ): + DecompClass( + info=epochs.info, indices=indices, mode=mode, fmin="15", fmax=fmax + ) + with pytest.raises( + TypeError, match="`fmax` must be an instance of int or float" + ): + DecompClass( + info=epochs.info, indices=indices, mode=mode, fmin=fmin, fmax="20" + ) + with pytest.raises(ValueError, match="`fmax` must be larger than `fmin`"): + DecompClass( + info=epochs.info, indices=indices, mode=mode, fmin=fmax, fmax=fmin + ) + with pytest.raises( + ValueError, match="`fmax` cannot be larger than the Nyquist frequency" + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=epochs.info["sfreq"] / 2 + 1, + ) # Test multitaper settings - with pytest.raises( - TypeError, match="`mt_bandwidth` must be an instance of int, float, or None" - ): - DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_bandwidth="5" - ) - with pytest.raises(TypeError, match="`mt_adaptive` must be an instance of bool"): - DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_adaptive=1 - ) - with pytest.raises(TypeError, match="`mt_low_bias` must be an instance of bool"): - DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, mt_low_bias=1 - ) + if mode == "multitaper": + with pytest.raises( + TypeError, match="`mt_bandwidth` must be an instance of int, float, or None" + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + mt_bandwidth="5", + ) + with pytest.raises( + TypeError, match="`mt_adaptive` must be an instance of bool" + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + mt_adaptive=1, + ) + with pytest.raises( + TypeError, match="`mt_low_bias` must be an instance of bool" + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + mt_low_bias=1, + ) # Test wavelet settings - with pytest.raises( - TypeError, match="`cwt_freq_resolution` must be an instance of int or float" - ): - DecompClass( - info=epochs.info, - fmin=fmin, - fmax=fmax, - indices=indices, - cwt_freq_resolution="1", - ) - with pytest.raises( - TypeError, - match=( - "`cwt_n_cycles` must be an instance of int, float, or array-like of ints " - "or floats" - ), - ): - DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, cwt_n_cycles="5" - ) + if mode == "cwt_morlet": + with pytest.raises( + TypeError, match="`cwt_freqs` must not be None if `mode` is 'cwt_morlet'" + ): + DecompClass(info=epochs.info, indices=indices, mode=mode, cwt_freqs=None) + with pytest.raises( + TypeError, 
match="`cwt_freqs` must be an instance of array-like" + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + cwt_freqs="1", + ) + with pytest.raises( + ValueError, + match=( + "last entry of `cwt_freqs` cannot be larger than the Nyquist frequency" + ), + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + cwt_freqs=np.array([epochs.info["sfreq"] / 2 + 1]), + cwt_n_cycles=cwt_n_cycles, + ) + with pytest.raises( + TypeError, + match="`cwt_n_cycles` must be an instance of int, float, or array-like", + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles="5", + ) + with pytest.raises( + ValueError, + match="`cwt_n_cycles` array-like must have the same length as `cwt_freqs`", + ): + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + cwt_freqs=cwt_freqs, + cwt_n_cycles=np.full(cwt_freqs.shape[0] - 1, 5), + ) # Test n_components with pytest.raises( TypeError, match="`n_components` must be an instance of int or None" ): DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, n_components="2" + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + n_components="2", ) # Test rank with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank="2") + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + rank="2", + ) with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=("2", "2") + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + rank=("2", "2"), ) with pytest.raises(ValueError, match="`rank` must have length 2"): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=(2,)) + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + rank=(2,), + ) with pytest.raises(ValueError, match="entries of `rank` must be > 0"): DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, rank=(0, 1) + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + rank=(0, 1), ) with pytest.raises( ValueError, @@ -391,9 +512,12 @@ def test_spectral_decomposition_error_catch(DecompClass): ): DecompClass( info=epochs.info, + indices=indices, + mode=mode, fmin=fmin, fmax=fmax, - indices=indices, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, rank=(n_seeds + 1, n_targets), ) with pytest.raises( @@ -405,25 +529,52 @@ def test_spectral_decomposition_error_catch(DecompClass): ): DecompClass( info=epochs.info, + indices=indices, + mode=mode, fmin=fmin, fmax=fmax, - indices=indices, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, rank=(n_seeds, n_targets + 1), ) # Test n_jobs with pytest.raises(TypeError, match="`n_jobs` must be an instance of int"): - DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, n_jobs="1") + DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + n_jobs="1", + ) # Test verbose with pytest.raises( TypeError, 
match="`verbose` must be an instance of bool, str, int, or None" ): DecompClass( - info=epochs.info, fmin=fmin, fmax=fmax, indices=indices, verbose=[True] + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + verbose=[True], ) - decomp_class = DecompClass(info=epochs.info, fmin=fmin, fmax=fmax, indices=indices) + decomp_class = DecompClass( + info=epochs.info, + indices=indices, + mode=mode, + fmin=fmin, + fmax=fmax, + cwt_freqs=cwt_freqs, + cwt_n_cycles=cwt_n_cycles, + ) # TEST BAD FITTING # Test input data diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index 3fc6e5c8..dbe89e3e 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -82,10 +82,10 @@ ``mode="multitaper"``. """ -docdict["cwt_freq_resolution"] = """ -cwt_freq_resolution : int | float (default 1) - The frequency resolution of the cross-spectral density in Hz. Only used if - ``mode=cwt_morlet``. +docdict["cwt_freqs"] = """ +cwt_freqs : array of int or float | None (default None) + The frequencies of interest in Hz. Must not be ``None`` and only used if + ``mode="cwt_morlet"``. """ docdict["cwt_n_cycles"] = """ @@ -192,13 +192,15 @@ """ docdict["fmin_decoding"] = """ -fmin : int | float - The lowest frequency of interest in Hz. +fmin : int | float | None (default None) + The lowest frequency of interest in Hz. Must not be ``None`` and only used if + ``mode in ["multitaper", "fourier"]``. """ docdict["fmax_decoding"] = """ -fmax : int | float - The highest frequency of interest in Hz. +fmax : int | float | None (default None) + The highest frequency of interest in Hz. Must not be ``None`` and only used if + ``mode in ["multitaper", "fourier"]``. """ docdict["indices_decoding"] = """ From cbfcc138b0339a1bfaa54f505318fc70ce9e628e Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 15:19:14 +0200 Subject: [PATCH 21/38] Fix platform-specific failing unit test --- mne_connectivity/decoding/tests/test_decomposition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index 21742f0c..6164a53a 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -158,9 +158,9 @@ def test_spectral_decomposition(DecompClass, mode): epochs_transformed_2 = decomp_class_2.transform( X=epochs[: n_epochs // 2].get_data() ) - assert_allclose(epochs_transformed, epochs_transformed_2) - assert_allclose(decomp_class.filters_, decomp_class_2.filters_) - assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_) + assert_allclose(epochs_transformed, epochs_transformed_2, rtol=1e-5) + assert_allclose(decomp_class.filters_, decomp_class_2.filters_, rtol=1e-5) + assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_, rtol=1e-5) # TEST FITTING ON ONE PIECE OF DATA AND TRANSFORMING ANOTHER con_mv_class_unseen_data = spectral_connectivity_epochs( From 8339d40423845cf3989c632bd06b46df9c8d9174 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 6 Jun 2024 16:13:54 +0200 Subject: [PATCH 22/38] Refactor decomposition classes --- doc/api.rst | 3 +- examples/decoding/cohy_decomposition.py | 52 +++--- mne_connectivity/decoding/__init__.py | 2 +- mne_connectivity/decoding/decomposition.py | 176 +++++++----------- .../decoding/tests/test_decomposition.py | 166 ++++++++++++----- 
mne_connectivity/utils/docs.py | 8 + 6 files changed, 220 insertions(+), 187 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index c64132cc..81c84418 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -60,8 +60,7 @@ connectivity, amplifying the signal-to-noise ratio of these interactions. .. autosummary:: :toctree: generated/ - CaCoh - MIC + CoherencyDecomposition Reading functions ================= diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 79923eb2..2f242f7d 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -28,7 +28,7 @@ seed_target_indices, spectral_connectivity_epochs, ) -from mne_connectivity.decoding import MIC, CaCoh +from mne_connectivity.decoding import CoherencyDecomposition ######################################################################################## # Background @@ -65,12 +65,12 @@ # setups where the rapid analysis of data is paramount, or even in offline analyses # with huge datasets. # -# These issues are addressed by the :class:`~mne_connectivity.decoding.CaCoh` and -# :class:`~mne_connectivity.decoding.MIC` decomposition classes of the decoding module. -# Here, the filters are fit for a given frequency band collectively (not each frequency -# bin!) and are stored, allowing them to be applied to the same data they were fit on -# (e.g. for offline analyses of huge datasets) or to new data (e.g. for online analyses -# of streamed data). +# These issues are addressed by the +# :class:`~mne_connectivity.decoding.CoherencyDecomposition` class of the decoding +# module. Here, the filters are fit for a given frequency band collectively (not each +# frequency bin!) and are stored, allowing them to be applied to the same data they were +# fit on (e.g. for offline analyses of huge datasets) or to new data (e.g. for online +# analyses of streamed data). # # In this example, we show how the tools of the decoding module compare to the standard # ``spectral_connectivity_...()`` functions in terms of their run time, and their @@ -152,27 +152,31 @@ # epochs to fit the filters, and then use these filters to extract the same components # from the last 30 epochs. # -# For this, we instantiate the :class:`~mne_connectivity.decoding.CaCoh` class with: the +# For this, we instantiate the +# :class:`~mne_connectivity.decoding.CoherencyDecomposition` class with: the # information about the data being fit/transformed (using an :class:`~mne.Info` object); -# the frequency band of the components we want to decompose (here 15-20 Hz); and the -# channel indices of the seeds and targets. +# the type of connectivity we want to decompose (here CaCoh); the frequency band of the +# components we want to decompose (here 15-20 Hz); and the channel indices of the seeds +# and targets. # -# Next, we call the :meth:`~mne_connectivity.decoding.CaCoh.fit` method, passing in the -# first 30 epochs of data we want to fit the filters to. Once the filters are fit, we -# can apply them to the last 30 epochs using the -# :meth:`~mne_connectivity.decoding.CaCoh.transform` method. +# Next, we call the :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit` +# method, passing in the first 30 epochs of data we want to fit the filters to. Once the +# filters are fit, we can apply them to the last 30 epochs using the +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.transform` method. 
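#
# A quick sketch of the workflow just described, not part of the original patch. It
# assumes the ``epochs`` and ``indices`` objects of this example and its 15-20 Hz band
# of interest (the ``FMIN`` and ``FMAX`` constants, defined outside the hunks shown
# here).

decomp = CoherencyDecomposition(
    info=epochs.info, method="cacoh", indices=indices, mode="multitaper",
    fmin=FMIN, fmax=FMAX,
)
decomp.fit(X=epochs[:30].get_data())  # fit filters to the first 30 epochs
X_trans = decomp.transform(X=epochs[30:].get_data())  # apply them to the last 30
print(X_trans.shape)  # (30, 2 * n_components, n_times): seed, then target components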
# # The transformed data has shape ``(epochs x components*2 x times)``, where the new # 'channels' are organised as the seed components, then target components. For -# convenience, the :meth:`~mne_connectivity.decoding.CaCoh.get_transformed_indices` +# convenience, the +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.get_transformed_indices` # method can be used to get the ``indices`` of the transformed data for use in the # ``spectral_connectivity_...()`` functions. # %% # Fit filters to first 30 epochs -cacoh = CaCoh( +cacoh = CoherencyDecomposition( info=epochs.info, + method="cacoh", indices=indices, mode="multitaper", fmin=FMIN, @@ -193,7 +197,8 @@ # # To compute connectivity of the transformed data, it is simply a case of passing to the # ``spectral_connectivity_...()`` functions: the transformed data; the indices -# returned from :meth:`~mne_connectivity.decoding.CaCoh.get_transformed_indices`; and +# returned from +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.get_transformed_indices`; and # the corresponding bivariate method (``"coh"`` and ``"cohy"`` for CaCoh; ``"imcoh"`` # for MIC). # @@ -266,8 +271,9 @@ # %% -cacoh = CaCoh( +cacoh = CoherencyDecomposition( info=epochs.info, + method="cacoh", indices=indices, mode="multitaper", fmin=FMIN, @@ -405,9 +411,10 @@ ######################################################################################## # There are two equivalent options for fitting and transforming the same data: 1) -# passing the data to the :meth:`~mne_connectivity.decoding.MIC.fit` and -# :meth:`~mne_connectivity.decoding.MIC.transform` methods sequentially; or 2) using the -# combined :meth:`~mne_connectivity.decoding.MIC.fit_transform` method. +# passing the data to the :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit` +# and :meth:`~mne_connectivity.decoding.CoherencyDecomposition.transform` methods +# sequentially; or 2) using the combined +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit_transform` method. # # We use the latter approach below, fitting the filters to the 15-20 Hz band and using # the ``"imcoh"`` method in the call to the ``spectral_connectivity_...()`` functions. 
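#
# A quick sketch of the two routes, not part of the original patch. It assumes the
# ``mic`` instance defined in the cell below and the ``epochs`` object of this
# example.

data = epochs.get_data()
X_seq = mic.fit(X=data).transform(X=data)  # route 1: `fit` returns the fitted instance
X_comb = mic.fit_transform(X=data)  # route 2: combined call
# the two outputs agree to floating-point precision (cf. the rtol=1e-5 test above)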
@@ -415,8 +422,9 @@ # %% -mic = MIC( +mic = CoherencyDecomposition( info=epochs.info, + method="mic", indices=(seeds, targets), mode="multitaper", fmin=FMIN, diff --git a/mne_connectivity/decoding/__init__.py b/mne_connectivity/decoding/__init__.py index 444470ae..f63899f3 100644 --- a/mne_connectivity/decoding/__init__.py +++ b/mne_connectivity/decoding/__init__.py @@ -1 +1 @@ -from .decomposition import MIC, CaCoh +from .decomposition import CoherencyDecomposition diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 2062d6d5..07cc689f 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -11,24 +11,70 @@ from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type -from ..spectral.epochs_multivariate import ( - _CaCohEst, - _check_rank_input, - _EpochMeanMultivariateConEstBase, - _MICEst, -) +from ..spectral.epochs_multivariate import _CaCohEst, _check_rank_input, _MICEst from ..utils import _check_multivariate_indices, fill_doc -class _AbstractDecompositionBase(BaseEstimator, TransformerMixin): - """ABC for multivariate connectivity signal decomposition.""" +@fill_doc +class CoherencyDecomposition(BaseEstimator, TransformerMixin): + """Decompose connectivity sources using multivariate coherency-based methods. + + Parameters + ---------- + %(info_decoding)s + %(method_decoding)s + %(indices_decoding)s + %(mode)s + %(fmin_decoding)s + %(fmax_decoding)s + %(mt_bandwidth)s + %(mt_adaptive)s + %(mt_low_bias)s + %(cwt_freqs)s + %(cwt_n_cycles)s + %(n_components)s + %(rank)s + %(n_jobs)s + %(verbose)s + + Attributes + ---------- + %(filters_)s + %(patterns_)s + + Notes + ----- + The multivariate methods maximise connectivity between a set of seed and target + signals in a frequency-resolved manner. The maximisation of connectivity involves + fitting spatial filters to the cross-spectral density of the seed and target data, + alongisde which spatial patterns of the contributions to connectivity can be + computed :footcite:`HaufeEtAl2014`. + + Once fit, the filters can be used to transform data into the underlying connectivity + components. Connectivity can be computed on this transformed data using the + bivariate coherency-based methods of the + `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. These bivariate methods + are: + + * ``"cohy"`` and ``"coh"`` for CaCoh :footcite:`VidaurreEtAl2019` + * ``"imcoh"`` for MIC :footcite:`EwaldEtAl2012` + + The approach taken here is to optimise the connectivity in a given frequency band. + Frequency bin-wise optimisation is offered in the multivariate coherency-based + methods of the `mne_connectivity.spectral_connectivity_epochs` and + `mne_connectivity.spectral_connectivity_time` functions. + + References + ---------- + .. 
footbibliography:: + """ filters_: Optional[tuple] = None patterns_: Optional[tuple] = None _indices: Optional[tuple] = None _rank: Optional[tuple] = None - _conn_estimator: Optional[_EpochMeanMultivariateConEstBase] = None @property def indices(self): @@ -61,6 +107,7 @@ def rank(self, rank): def __init__( self, info, + method, indices, mode="multitaper", fmin=None, @@ -79,6 +126,12 @@ def __init__( # Validate inputs _validate_type(info, Info, "`info`", "mne.Info") + _check_option("method", method, ("cacoh", "mic")) + if method == "cacoh": + _conn_estimator = _CaCohEst + else: + _conn_estimator = _MICEst + _validate_type(indices, tuple, "`indices`", "tuple of array-likes") if len(indices) != 2: raise ValueError("`indices` must have length 2") @@ -157,6 +210,7 @@ def __init__( # Store inputs self.info = info + self._conn_estimator = _conn_estimator self._indices = _indices # uses getter/setter for public parameter self.mode = mode self.fmin = fmin @@ -218,7 +272,7 @@ def fit(self, X, y=None): Returns ------- - self : instance of CaCoh | MIC + self : instance of CoherencyDecomposition The modified class instance. """ # validate input data @@ -405,105 +459,3 @@ def get_transformed_indices(self): np.arange(self.n_components), np.arange(self.n_components) + self.n_components, ) - - -@fill_doc -class CaCoh(_AbstractDecompositionBase): - """Decompose connectivity sources using canonical coherency (CaCoh). - - CaCoh is a multivariate approach to maximise coherency/coherence between a set of - seed and target signals in a frequency-resolved manner :footcite:`VidaurreEtAl2019`. - The maximisation of connectivity involves fitting spatial filters to the - cross-spectral density of the seed and target data, alongisde which spatial patterns - of the contributions to connectivity can be computed :footcite:`HaufeEtAl2014`. - - Once fit, the filters can be used to transform data into the underlying connectivity - components. Connectivity can be computed on this transformed data using the - ``"coh"`` and ``"cohy"`` methods of the - `mne_connectivity.spectral_connectivity_epochs` and - `mne_connectivity.spectral_connectivity_time` functions. - - The approach taken here is to optimise the connectivity in a given frequency band. - Frequency bin-wise optimisation is offered in the ``"cacoh"`` method of the - `mne_connectivity.spectral_connectivity_epochs` and - `mne_connectivity.spectral_connectivity_time` functions. - - Parameters - ---------- - %(info_decoding)s - %(indices_decoding)s - %(mode)s - %(fmin_decoding)s - %(fmax_decoding)s - %(mt_bandwidth)s - %(mt_adaptive)s - %(mt_low_bias)s - %(cwt_freqs)s - %(cwt_n_cycles)s - %(n_components)s - %(rank)s - %(n_jobs)s - %(verbose)s - - Attributes - ---------- - %(filters_)s - %(patterns_)s - - References - ---------- - .. footbibliography:: - """ - - _conn_estimator = _CaCohEst - - -@fill_doc -class MIC(_AbstractDecompositionBase): - """Decompose connectivity sources using maximised imaginary coherency (MIC). - - MIC is a multivariate approach to maximise the imaginary part of coherency between a - set of seed and target signals in a frequency-resolved manner - :footcite:`EwaldEtAl2012`. The maximisation of connectivity involves fitting spatial - filters to the cross-spectral density of the seed and target data, alongisde which - spatial patterns of the contributions to connectivity can be computed - :footcite:`HaufeEtAl2014`. - - Once fit, the filters can be used to transform data into the underlying connectivity - components. 
Connectivity can be computed on this transformed data using the - ``"imcoh"`` method of the `mne_connectivity.spectral_connectivity_epochs` and - `mne_connectivity.spectral_connectivity_time` functions. - - The approach taken here is to optimise the connectivity in a given frequency band. - Frequency bin-wise optimisation is offered in the ``"mic"`` method of the - `mne_connectivity.spectral_connectivity_epochs` and - `mne_connectivity.spectral_connectivity_time` functions. - - Parameters - ---------- - %(info_decoding)s - %(indices_decoding)s - %(mode)s - %(fmin_decoding)s - %(fmax_decoding)s - %(mt_bandwidth)s - %(mt_adaptive)s - %(mt_low_bias)s - %(cwt_freqs)s - %(cwt_n_cycles)s - %(n_components)s - %(rank)s - %(n_jobs)s - %(verbose)s - - Attributes - ---------- - %(filters_)s - %(patterns_)s - - References - ---------- - .. footbibliography:: - """ - - _conn_estimator = _MICEst diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index 6164a53a..f6a4e7db 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -7,13 +7,13 @@ seed_target_indices, spectral_connectivity_epochs, ) -from mne_connectivity.decoding import MIC, CaCoh +from mne_connectivity.decoding import CoherencyDecomposition from mne_connectivity.utils import _check_multivariate_indices -@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +@pytest.mark.parametrize("method", ["cacoh", "mic"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) -def test_spectral_decomposition(DecompClass, mode): +def test_spectral_decomposition(method, mode): """Test spectral decomposition classes run and give expected results.""" # SIMULATE DATA # Settings @@ -61,16 +61,21 @@ def test_spectral_decomposition(DecompClass, mode): ) indices = (seeds, targets) - bivariate_method = "coh" if DecompClass == CaCoh else "imcoh" - multivariate_method = "cacoh" if DecompClass == CaCoh else "mic" + if method == "cacoh": + bivariate_method = "coh" + multivariate_method = "cacoh" + else: + bivariate_method = "imcoh" + multivariate_method = "mic" cwt_freq_res = 0.5 cwt_freqs = np.arange(fmin_optimise, fmax_optimise + cwt_freq_res, cwt_freq_res) cwt_n_cycles = 6 # TEST FITTING AND TRANSFORMING SAME DATA EXTRACTS CONNECTIVITY - decomp_class = DecompClass( + decomp_class = CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin_optimise, @@ -145,8 +150,9 @@ def test_spectral_decomposition(DecompClass, mode): ) # check connectivity for ignored freq. 
band lower than with optimisation # Test `fit_transform` equivalent to `fit` and `transform` separately - decomp_class_2 = DecompClass( + decomp_class_2 = CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin_optimise, @@ -207,9 +213,9 @@ def test_spectral_decomposition(DecompClass, mode): assert np.all(decomp_class._rank == ([1], [2])) -@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +@pytest.mark.parametrize("method", ["cacoh", "mic"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) -def test_spectral_decomposition_parallel(DecompClass, mode): +def test_spectral_decomposition_parallel(method, mode): """Test spectral decomposition classes run with parallelisation.""" # SIMULATE DATA n_seeds = 3 @@ -225,8 +231,9 @@ def test_spectral_decomposition_parallel(DecompClass, mode): ) # RUN DECOMPOSITION - decomp_class = DecompClass( + decomp_class = CoherencyDecomposition( info=epochs.info, + method=method, indices=(np.arange(n_seeds), np.arange(n_targets) + n_seeds), mode=mode, fmin=fmin, @@ -238,9 +245,9 @@ def test_spectral_decomposition_parallel(DecompClass, mode): decomp_class.fit_transform(X=epochs.get_data()) -@pytest.mark.parametrize("DecompClass", [CaCoh, MIC]) +@pytest.mark.parametrize("method", ["cacoh", "mic"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) -def test_spectral_decomposition_error_catch(DecompClass, mode): +def test_spectral_decomposition_error_catch(method, mode): """Test error catching for spectral decomposition classes.""" # SIMULATE DATA n_seeds = 3 @@ -257,19 +264,19 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): # TEST BAD INITIALISATION # Test info with pytest.raises(TypeError, match="`info` must be an instance of mne.Info"): - DecompClass(info="info", indices=indices) + CoherencyDecomposition(info="info", method=method, indices=indices) # Test indices with pytest.raises( TypeError, match="`indices` must be an instance of tuple of array-likes" ): - DecompClass(info=epochs.info, indices=list(indices)) + CoherencyDecomposition(info=epochs.info, method=method, indices=list(indices)) with pytest.raises( TypeError, match="`indices` must be an instance of tuple of array-likes" ): - DecompClass(info=epochs.info, indices=(0, 1)) + CoherencyDecomposition(info=epochs.info, method=method, indices=(0, 1)) with pytest.raises(ValueError, match="`indices` must have length 2"): - DecompClass(info=epochs.info, indices=(indices[0],)) + CoherencyDecomposition(info=epochs.info, method=method, indices=(indices[0],)) with pytest.raises( ValueError, match=( @@ -277,7 +284,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "target" ), ): - DecompClass(info=epochs.info, indices=([0, 0], [1, 2])) + CoherencyDecomposition( + info=epochs.info, method=method, indices=([0, 0], [1, 2]) + ) with pytest.raises( ValueError, match=( @@ -285,11 +294,15 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "target" ), ): - DecompClass(info=epochs.info, indices=([0, 1], [2, 2])) + CoherencyDecomposition( + info=epochs.info, method=method, indices=([0, 1], [2, 2]) + ) with pytest.raises( ValueError, match="a negative channel index is not present in the data" ): - DecompClass(info=epochs.info, indices=([0], [(n_seeds + n_targets) * -1])) + CoherencyDecomposition( + info=epochs.info, method=method, indices=([0], [(n_seeds + n_targets) * -1]) + ) with pytest.raises( ValueError, match=( @@ -297,11 +310,15 @@ def 
test_spectral_decomposition_error_catch(DecompClass, mode): "`info`" ), ): - DecompClass(info=epochs.info, indices=([0], [n_seeds + n_targets])) + CoherencyDecomposition( + info=epochs.info, method=method, indices=([0], [n_seeds + n_targets]) + ) # Test mode with pytest.raises(ValueError, match="Invalid value for the 'mode' parameter"): - DecompClass(info=epochs.info, indices=indices, mode="notamode") + CoherencyDecomposition( + info=epochs.info, method=method, indices=indices, mode="notamode" + ) # Test fmin & fmax if mode in ["multitaper", "fourier"]: @@ -312,8 +329,13 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "'fourier'" ), ): - DecompClass( - info=epochs.info, indices=indices, mode=mode, fmin=None, fmax=fmax + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + fmin=None, + fmax=fmax, ) with pytest.raises( TypeError, @@ -322,30 +344,51 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "'fourier'" ), ): - DecompClass( - info=epochs.info, indices=indices, mode=mode, fmin=fmin, fmax=None + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + fmin=fmin, + fmax=None, ) with pytest.raises( TypeError, match="`fmin` must be an instance of int or float" ): - DecompClass( - info=epochs.info, indices=indices, mode=mode, fmin="15", fmax=fmax + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + fmin="15", + fmax=fmax, ) with pytest.raises( TypeError, match="`fmax` must be an instance of int or float" ): - DecompClass( - info=epochs.info, indices=indices, mode=mode, fmin=fmin, fmax="20" + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + fmin=fmin, + fmax="20", ) with pytest.raises(ValueError, match="`fmax` must be larger than `fmin`"): - DecompClass( - info=epochs.info, indices=indices, mode=mode, fmin=fmax, fmax=fmin + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + fmin=fmax, + fmax=fmin, ) with pytest.raises( ValueError, match="`fmax` cannot be larger than the Nyquist frequency" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -357,8 +400,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`mt_bandwidth` must be an instance of int, float, or None" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -368,8 +412,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`mt_adaptive` must be an instance of bool" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -379,8 +424,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`mt_low_bias` must be an instance of bool" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -393,12 +439,19 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`cwt_freqs` must not be None if `mode` is 'cwt_morlet'" ): - DecompClass(info=epochs.info, indices=indices, mode=mode, cwt_freqs=None) + CoherencyDecomposition( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + cwt_freqs=None, + ) with pytest.raises( 
TypeError, match="`cwt_freqs` must be an instance of array-like" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, cwt_freqs="1", @@ -409,8 +462,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "last entry of `cwt_freqs` cannot be larger than the Nyquist frequency" ), ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, cwt_freqs=np.array([epochs.info["sfreq"] / 2 + 1]), @@ -420,8 +474,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): TypeError, match="`cwt_n_cycles` must be an instance of int, float, or array-like", ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, cwt_freqs=cwt_freqs, @@ -431,8 +486,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): ValueError, match="`cwt_n_cycles` array-like must have the same length as `cwt_freqs`", ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, cwt_freqs=cwt_freqs, @@ -443,8 +499,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`n_components` must be an instance of int or None" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -458,8 +515,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -471,8 +529,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -482,8 +541,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): rank=("2", "2"), ) with pytest.raises(ValueError, match="`rank` must have length 2"): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -493,8 +553,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): rank=(2,), ) with pytest.raises(ValueError, match="entries of `rank` must be > 0"): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -510,8 +571,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "channels in `indices`" ), ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -527,8 +589,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): "channels in `indices`" ), ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -540,8 +603,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): # Test n_jobs with pytest.raises(TypeError, match="`n_jobs` must be an instance of int"): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, @@ -555,8 +619,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): with pytest.raises( TypeError, match="`verbose` must be an instance of bool, str, int, or None" ): - DecompClass( + CoherencyDecomposition( info=epochs.info, + method=method, 
indices=indices, mode=mode, fmin=fmin, @@ -566,8 +631,9 @@ def test_spectral_decomposition_error_catch(DecompClass, mode): verbose=[True], ) - decomp_class = DecompClass( + decomp_class = CoherencyDecomposition( info=epochs.info, + method=method, indices=indices, mode=mode, fmin=fmin, diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index dbe89e3e..a42702cd 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -191,6 +191,14 @@ subsequent input data. """ +docdict["method_decoding"] = """ +method : str + The multivariate method to use for the decomposition. Can be: + + * ``"cacoh"`` - Canonical Coherency (CaCoh) :footcite:`VidaurreEtAl2019` + * ``"mic"`` - Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` +""" + docdict["fmin_decoding"] = """ fmin : int | float | None (default None) The lowest frequency of interest in Hz. Must not be ``None`` and only used if From bdfb626382a8172fb5afad2c79ddafeb0e0e47fb Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 10 Jun 2024 15:45:36 +0200 Subject: [PATCH 23/38] Add test reminder --- mne_connectivity/decoding/tests/test_decomposition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index f6a4e7db..36fb42ed 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -191,6 +191,8 @@ def test_spectral_decomposition(method, mode): atol=similarity_thresh, ) # check connectivity for optimised freq. band similarly low for seen & unseen + # XXX: TEST FILTERS/PATTERNS HAS CORRECT SHAPE WHEN N_COMPONENTS > 1 SUPPORTED + # TEST GETTERS & SETTERS # Test indices internal storage and returned format assert np.all(np.array(decomp_class.indices) == np.array((seeds, targets))) From a0e9560666c70a96ba45268a6da510414db6f092 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 10 Jun 2024 17:35:29 +0200 Subject: [PATCH 24/38] Add decomposition plotting --- examples/decoding/cohy_decomposition.py | 66 +++++- mne_connectivity/decoding/decomposition.py | 206 ++++++++++++++++++ .../decoding/tests/test_decomposition.py | 40 ++++ mne_connectivity/utils/docs.py | 5 + 4 files changed, 312 insertions(+), 5 deletions(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 2f242f7d..58dafb3e 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -45,13 +45,15 @@ # Coherency-based methods are popular approaches for analysing connectivity, capturing # correlation between signals in the frequency domain. Various coherency-based # multivariate methods exist, including: canonical coherency (CaCoh; multivariate -# measure of coherency/coherence); and maximised imaginary coherency (MIC; multivariate -# measure of the imaginary part of coherency). +# measure of coherency/coherence) :footcite:`VidaurreEtAl2019` ; and maximised imaginary +# coherency (MIC; multivariate measure of the imaginary part of coherency) +# :footcite:`EwaldEtAl2012`. 
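#
# A quick sketch, not part of the original patch: after this refactor, both methods
# are requested through the ``method`` parameter of the unified class. An ``epochs``
# object and seed/target channel ``indices`` are assumed from this example.

from mne_connectivity.decoding import CoherencyDecomposition

cacoh = CoherencyDecomposition(
    info=epochs.info, method="cacoh", indices=indices, fmin=15, fmax=20
)  # canonical coherency of the 15-20 Hz band, with the default multitaper mode
mic = CoherencyDecomposition(
    info=epochs.info, method="mic", indices=indices, fmin=15, fmax=20
)  # maximised imaginary coherency of the same band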
# # These methods are described in detail in the following examples: -# - comparison of coherency-based methods - :doc:`../compare_coherency_methods` -# - CaCoh - :doc:`../cacoh` -# - MIC - :doc:`../mic_mim` +# +# - comparison of coherency-based methods - :doc:`../compare_coherency_methods` +# - CaCoh - :doc:`../cacoh` +# - MIC - :doc:`../mic_mim` # # The CaCoh and MIC methods work by finding spatial filters that decompose the data into # components of connectivity, and applying them to the data. With the implementations @@ -526,6 +528,55 @@ f"{func_duration:.2f} s" ) +######################################################################################## +# Visualising filters and patterns +# -------------------------------- +# In addition to the connectivity scores, useful insights about the data can be gained +# by visualising the topographies of the filters and patterns, which represent two +# complementary aspects: +# +# - The filters represent how the connectivity sources are extracted from the channel +# data, akin to an inverse model. +# - The patterns represent how the channel data is formed by the connectivity sources, +# akin to a forward model. +# +# This distinction is discussed further in Haufe *et al.* (2014) +# :footcite:`HaufeEtAl2014`, but in short: **the patterns should be used to interpret +# the contribution of distinct brain regions/sensors to a given component of +# connectivity**. Accordingly, keep in mind that the filters and patterns are not a +# replacement for source reconstruction, as without this the patterns will still only +# tell you about the spatial contributions of sensors, not underlying brain regions, +# to connectivity. +# +# Visualising these topographies can be done using the +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_filters` and +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_patterns` methods. +# +# When interpreting patterns, note that the absolute value reflects the strength of the +# contribution to connectivity, however the sign differences can be used to visualise +# the orientation of the underlying dipole sources. The spatial patterns are **not** +# bound between :math:`[-1, 1]`. +# +# Plotting the patterns below, we can infer the existence of postcentral, generally +# medial dipole sources contributing to the connectivity between sensors over left and +# right hemispheres at 15-20 Hz. + +# %% + +# Plot patterns +mic.plot_patterns(epochs.info, sensors="m.", size=2) + +######################################################################################## +# For comparison we can also plot the filters, and here we see that they show a very +# similar topography to the patterns. However, this is not always the case, and you +# should never confuse the information represented by the filters and patterns, which +# can lead to very incorrect interpretations of the data :footcite:`HaufeEtAl2014`. + +# %% + +# Plot filters +mic.plot_filters(epochs.info, sensors="m.", size=2) + ######################################################################################## # Limitations # ----------- @@ -554,4 +605,9 @@ # Ultimately, there are distinct advantages and disadvantages to both approaches, and # one may be more suitable than the other depending on your use case. +######################################################################################## +# References +# ---------- +# .. 
footbibliography:: + # %% diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 07cc689f..0b1e7d5a 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -2,11 +2,15 @@ # # License: BSD (3-clause) +import copy as cp from typing import Optional import numpy as np from mne import Info +from mne._fiff.pick import pick_info from mne.decoding.mixin import TransformerMixin +from mne.defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT +from mne.evoked import EvokedArray from mne.fixes import BaseEstimator from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type @@ -459,3 +463,205 @@ def get_transformed_indices(self): np.arange(self.n_components), np.arange(self.n_components) + self.n_components, ) + + @fill_doc + def plot_patterns(self, info, **kwargs): + """Plot topographic patterns of components. + + The patterns explain how the measured data was generated from the + neural sources (a.k.a. the forward model) :footcite:`HaufeEtAl2014`. + + Seed and target patterns are plotted separately. + + Parameters + ---------- + %(info_not_none)s + components : float | array of float | None + The patterns to plot. If ``None``, all components will be shown. + %(average_plot_evoked_topomap)s + %(ch_type_topomap)s + scalings : dict | float | None + The scalings of the channel types to be applied for plotting. + If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``. + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_patterns_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap)s + %(axes_evoked_plot_topomap)s + name_format : str | None + String format for topomap values. ``None`` defaults to f"{method}%%01d". + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + figs : list of instance of matplotlib.figure.Figure + The seed and target figures, respectively. + """ + if self.patterns_ is None: + raise RuntimeError( + "no patterns are available, please call the `fit` method first" + ) + + return self._plot_filters_patterns( + (self.patterns_[0].T, self.patterns_[1].T), info, **kwargs + ) + + @fill_doc + def plot_filters(self, info, **kwargs): + """Plot topographic filters of components. + + The filters are used to extract discriminant neural sources from the measured + data (a.k.a. the backward model). :footcite:`HaufeEtAl2014`. + + Seed and target filters are plotted separately. + + Parameters + ---------- + %(info_not_none)s + components : float | array of float | None + The filters to plot. If ``None``, all components will be shown. + %(average_plot_evoked_topomap)s + %(ch_type_topomap)s + scalings : dict | float | None + The scalings of the channel types to be applied for plotting. + If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``. 
+ %(sensors_topomap)s + %(show_names_topomap)s + %(mask_patterns_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap)s + %(axes_evoked_plot_topomap)s + name_format : str | None + String format for topomap values. ``None`` defaults to f"{method}%%01d". + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + figs : list of instance of matplotlib.figure.Figure + The seed and target figures, respectively. + """ + if self.filters_ is None: + raise RuntimeError( + "no filters are available, please call the `fit` method first" + ) + + return self._plot_filters_patterns(self.filters_, info, **kwargs) + + def _plot_filters_patterns( + self, + plot_data, + info, + components=None, + average=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%.1E", + units=None, + axes=None, + name_format=None, + nrows=1, + ncols="auto", + show=True, + ): + """Plot filters/targets for components.""" + # Sort inputs + _validate_type(info, Info, "`info`", "mne.Info") + if units is None: + units = "AU" + if components is None: + components = np.arange(self.n_components) + + # plot seeds and targets + figs = [] + for group_idx, group_name in zip([0, 1], ["Seeds", "Targets"]): + # create info for seeds/targets + group_info = cp.deepcopy(info) + group_info = pick_info(group_info, self.indices[group_idx], copy=False) + with group_info._unlock(): + group_info["sfreq"] = 1.0 # 1 component per time point + # create Evoked object + evoked = EvokedArray(plot_data[group_idx], group_info, tmin=0) + # then call plot_topomap + figs.append( + evoked.plot_topomap( + times=components, + average=average, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=f"{self._conn_estimator.name}%01d" + if name_format is None + else name_format, + nrows=nrows, + ncols=ncols, + show=False, # set Seeds/Targets suptitle first + ) + ) + figs[-1].suptitle(group_name) # differentiate seeds from targets + if show: + figs[-1].show() + + return figs diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index f6a4e7db..75912115 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from mne.channels import make_dig_montage, make_standard_montage from numpy.testing import assert_allclose from mne_connectivity import ( @@ -212,6 +213,27 @@ def test_spectral_decomposition(method, mode): assert np.all(decomp_class.rank == (1, 2)) assert np.all(decomp_class._rank == ([1], [2])) + # TEST PLOTTING + # Test plot filters/patterns + # use standard 
montage to avoid errors around weird fiducial positions + standard_1020_pos = make_standard_montage("standard_1020").get_positions() + epochs.info.set_montage( + make_dig_montage( + ch_pos={ + name: [idx, idx, idx] + for idx, name in enumerate(epochs.info["ch_names"]) + }, # avoid overlapping positions for channels (raises error) + nasion=standard_1020_pos["nasion"], + lpa=standard_1020_pos["lpa"], + rpa=standard_1020_pos["rpa"], + ) + ) + for plot in (decomp_class.plot_filters, decomp_class.plot_patterns): + # XXX: required for this to be picked up by coverage + figs = plot(epochs.info, components=0, units="A.U.", show=False) + figs = plot(epochs.info, components=None, units=None, show=False) + assert len(figs) == 2 + @pytest.mark.parametrize("method", ["cacoh", "mic"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) @@ -659,6 +681,19 @@ def test_spectral_decomposition_error_catch(method, mode): ): decomp_class.transform(X=epochs.get_data()) + # TEST PLOTTING BEFORE FITTING + with pytest.raises( + RuntimeError, + match="no filters are available, please call the `fit` method first", + ): + decomp_class.plot_filters(epochs.info) + + with pytest.raises( + RuntimeError, + match="no patterns are available, please call the `fit` method first", + ): + decomp_class.plot_patterns(epochs.info) + decomp_class.fit(X=epochs.get_data()) # TEST BAD TRANSFORMING @@ -668,3 +703,8 @@ def test_spectral_decomposition_error_catch(method, mode): decomp_class.transform(X=epochs.get_data()[0, 0]) with pytest.raises(ValueError, match="`X` does not match Info"): decomp_class.transform(X=epochs.get_data()[:, :-1]) + + # TEST BAD PLOTTING + for plot in (decomp_class.plot_filters, decomp_class.plot_patterns): + with pytest.raises(TypeError, match="`info` must be an instance of mne.Info"): + plot({"info": epochs.info}) diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index a42702cd..bc4d0461 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -8,6 +8,8 @@ except ImportError: from mne.externals.doccer import indentcount_lines as _indentcount_lines # noqa +from mne.utils.docs import docdict as mne_docdict + ############################################################################## # Define our standard documentation entries @@ -250,6 +252,9 @@ filters for the seed and target data, respectively. 
""" +for key, val in mne_docdict.items(): + if key not in docdict: + docdict[key] = val docdict_indented = dict() # type: ignore From 840f4f61959b33eff4a1cd77576cd95b5b43a9ef Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 12 Jun 2024 16:51:09 +0200 Subject: [PATCH 25/38] Update tests and fix getter/setters --- mne_connectivity/decoding/decomposition.py | 19 +++++++++++---- .../decoding/tests/test_decomposition.py | 23 +++++++++++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 07cc689f..b9436621 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -97,12 +97,17 @@ def rank(self): :meta private: """ - return (self._rank[0][0], self._rank[1][0]) + if self._rank is not None: + return (self._rank[0][0], self._rank[1][0]) + return None @rank.setter def rank(self, rank): """Set ``rank`` parameter using the input format.""" - self._rank = ([rank[0]], [rank[1]]) + if rank is None: + self._rank = None + else: + self._rank = ([rank[0]], [rank[1]]) def __init__( self, @@ -318,16 +323,20 @@ def _check_X(self, X, ndim): def _get_rank_and_ncomps_from_X(self, X): """Get/validate rank and n_components parameters using the data.""" # compute rank from data if necessary / check it is valid for the indices - self._rank = _check_rank_input(self._rank, X, self._indices) + rank = _check_rank_input(self._rank, X, self._indices) # set n_components if necessary / check it is valid for the rank if self.n_components is None: - self.n_components = np.min(self.rank) - elif self.n_components > np.min(self.rank): + self.n_components = np.min(rank) + elif self.n_components > np.min(rank): raise ValueError( "`n_components` is greater than the minimum rank of the data" ) + # set rank if necessary + if self._rank is None: + self._rank = rank + def _compute_csd(self, X): """Compute the cross-spectral density of the input data.""" # XXX: fix csd returning [fmin +1 bin to fmax -1 bin] frequencies diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index 36fb42ed..7ecc07bf 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -214,6 +214,9 @@ def test_spectral_decomposition(method, mode): assert np.all(decomp_class.rank == (1, 2)) assert np.all(decomp_class._rank == ([1], [2])) + # Test rank can be reset to default + decomp_class.set_params(rank=None) + @pytest.mark.parametrize("method", ["cacoh", "mic"]) @pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) @@ -645,14 +648,29 @@ def test_spectral_decomposition_error_catch(method, mode): ) # TEST BAD FITTING - # Test input data + # Test input data format with pytest.raises(TypeError, match="`X` must be an instance of NumPy array"): decomp_class.fit(X=epochs.get_data().tolist()) with pytest.raises(ValueError, match="Invalid value for the '`X.ndim`' parameter"): decomp_class.fit(X=epochs.get_data()[0]) with pytest.raises(ValueError, match="`X` does not match Info"): decomp_class.fit(X=epochs.get_data()[:, :-1]) - # XXX: Add test for rank of X being <= n_components when n_components can be > 1 + # Test rank of input data is compatible with n_components + decomp_class.set_params(n_components=3) + with pytest.raises( + ValueError, match="`n_components` is greater than the minimum rank of the data" + ): + rank_def_data = 
epochs.get_data(copy=True) + rank_def_data[:, n_seeds - 1] = rank_def_data[:, n_seeds - 2] + decomp_class.fit(X=rank_def_data) + with pytest.raises( + ValueError, match="`n_components` is greater than the minimum rank of the data" + ): + rank_def_data = epochs.get_data(copy=True) + rank_def_data[:, n_seeds + n_targets - 1] = rank_def_data[ + :, n_seeds + n_targets - 2 + ] + decomp_class.fit(X=rank_def_data) # TEST TRANSFORM BEFORE FITTING with pytest.raises( @@ -661,6 +679,7 @@ def test_spectral_decomposition_error_catch(method, mode): ): decomp_class.transform(X=epochs.get_data()) + decomp_class.set_params(n_components=None) # reset to default decomp_class.fit(X=epochs.get_data()) # TEST BAD TRANSFORMING From b8b57bbcf11fba7013b02ff8032b590cf6982a6d Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 12 Jun 2024 19:59:42 +0200 Subject: [PATCH 26/38] Switch from matmul to at --- mne_connectivity/spectral/epochs_multivariate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 9ad63b15..3302e1ff 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -459,8 +459,8 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, U_bar_aa, U_bar_bb, con_i): beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] # Part of Eqs. 46 & 47; i.e. transform filters to channel space - alpha_Ubar = np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3)) - beta_Ubar = np.matmul(U_bar_bb, np.expand_dims(beta, axis=3)) + alpha_Ubar = U_bar_aa @ np.expand_dims(alpha, axis=3) + beta_Ubar = U_bar_bb @ np.expand_dims(beta, axis=3) # Eq. 46 (seed spatial patterns) self.patterns[0, con_i, :n_seeds] = ( @@ -660,8 +660,8 @@ def _compute_patterns( beta = T_bb @ np.expand_dims(b, axis=3) # filter for targets # Eqs. 46 & 47 of Ewald et al. (2012); i.e. transform filters to channel space - alpha_Ubar = np.matmul(U_bar_aa, alpha) - beta_Ubar = np.matmul(U_bar_bb, beta) + alpha_Ubar = U_bar_aa @ alpha + beta_Ubar = U_bar_bb @ beta # Eq. 
14 # seed spatial patterns From 88eafb5de21d2aba90cd23d7ad913888755a6332 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 13 Jun 2024 10:33:05 +0200 Subject: [PATCH 27/38] Shorten tests with kwargs --- .../decoding/tests/test_decomposition.py | 245 +++--------------- 1 file changed, 35 insertions(+), 210 deletions(-) diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index 7ecc07bf..d76015f1 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -325,6 +325,13 @@ def test_spectral_decomposition_error_catch(method, mode): info=epochs.info, method=method, indices=indices, mode="notamode" ) + base_kwargs = dict( + info=epochs.info, + method=method, + indices=indices, + mode=mode, + ) + # Test fmin & fmax if mode in ["multitaper", "fourier"]: with pytest.raises( @@ -334,14 +341,7 @@ def test_spectral_decomposition_error_catch(method, mode): "'fourier'" ), ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=None, - fmax=fmax, - ) + CoherencyDecomposition(**base_kwargs, fmin=None, fmax=fmax) with pytest.raises( TypeError, match=( @@ -349,55 +349,22 @@ def test_spectral_decomposition_error_catch(method, mode): "'fourier'" ), ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=None, - ) + CoherencyDecomposition(**base_kwargs, fmin=fmin, fmax=None) with pytest.raises( TypeError, match="`fmin` must be an instance of int or float" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin="15", - fmax=fmax, - ) + CoherencyDecomposition(**base_kwargs, fmin="15", fmax=fmax) with pytest.raises( TypeError, match="`fmax` must be an instance of int or float" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax="20", - ) + CoherencyDecomposition(**base_kwargs, fmin=fmin, fmax="20") with pytest.raises(ValueError, match="`fmax` must be larger than `fmin`"): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmax, - fmax=fmin, - ) + CoherencyDecomposition(**base_kwargs, fmin=fmax, fmax=fmin) with pytest.raises( ValueError, match="`fmax` cannot be larger than the Nyquist frequency" ): CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=epochs.info["sfreq"] / 2 + 1, + **base_kwargs, fmin=fmin, fmax=epochs.info["sfreq"] / 2 + 1 ) # Test multitaper settings @@ -406,61 +373,27 @@ def test_spectral_decomposition_error_catch(method, mode): TypeError, match="`mt_bandwidth` must be an instance of int, float, or None" ): CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - mt_bandwidth="5", + **base_kwargs, fmin=fmin, fmax=fmax, mt_bandwidth="5" ) with pytest.raises( TypeError, match="`mt_adaptive` must be an instance of bool" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - mt_adaptive=1, - ) + CoherencyDecomposition(**base_kwargs, fmin=fmin, fmax=fmax, mt_adaptive=1) with pytest.raises( TypeError, match="`mt_low_bias` must be an instance of bool" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - 
mt_low_bias=1, - ) + CoherencyDecomposition(**base_kwargs, fmin=fmin, fmax=fmax, mt_low_bias=1) # Test wavelet settings if mode == "cwt_morlet": with pytest.raises( TypeError, match="`cwt_freqs` must not be None if `mode` is 'cwt_morlet'" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - cwt_freqs=None, - ) + CoherencyDecomposition(**base_kwargs, cwt_freqs=None) with pytest.raises( TypeError, match="`cwt_freqs` must be an instance of array-like" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - cwt_freqs="1", - ) + CoherencyDecomposition(**base_kwargs, cwt_freqs="1") with pytest.raises( ValueError, match=( @@ -468,10 +401,7 @@ def test_spectral_decomposition_error_catch(method, mode): ), ): CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, + **base_kwargs, cwt_freqs=np.array([epochs.info["sfreq"] / 2 + 1]), cwt_n_cycles=cwt_n_cycles, ) @@ -479,96 +409,40 @@ def test_spectral_decomposition_error_catch(method, mode): TypeError, match="`cwt_n_cycles` must be an instance of int, float, or array-like", ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - cwt_freqs=cwt_freqs, - cwt_n_cycles="5", - ) + CoherencyDecomposition(**base_kwargs, cwt_freqs=cwt_freqs, cwt_n_cycles="5") with pytest.raises( ValueError, match="`cwt_n_cycles` array-like must have the same length as `cwt_freqs`", ): CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, + **base_kwargs, cwt_freqs=cwt_freqs, cwt_n_cycles=np.full(cwt_freqs.shape[0] - 1, 5), ) + base_kwargs.update( + fmin=fmin, fmax=fmax, cwt_freqs=cwt_freqs, cwt_n_cycles=cwt_n_cycles + ) + # Test n_components with pytest.raises( TypeError, match="`n_components` must be an instance of int or None" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - n_components="2", - ) + CoherencyDecomposition(**base_kwargs, n_components="2") # Test rank with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank="2", - ) + CoherencyDecomposition(**base_kwargs, rank="2") with pytest.raises( TypeError, match="`rank` must be an instance of tuple of ints or None" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank=("2", "2"), - ) + CoherencyDecomposition(**base_kwargs, rank=("2", "2")) with pytest.raises(ValueError, match="`rank` must have length 2"): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank=(2,), - ) + CoherencyDecomposition(**base_kwargs, rank=(2,)) with pytest.raises(ValueError, match="entries of `rank` must be > 0"): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank=(0, 1), - ) + CoherencyDecomposition(**base_kwargs, rank=(0, 1)) with pytest.raises( ValueError, match=( @@ -576,17 +450,7 @@ def 
test_spectral_decomposition_error_catch(method, mode): "channels in `indices`" ), ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank=(n_seeds + 1, n_targets), - ) + CoherencyDecomposition(**base_kwargs, rank=(n_seeds + 1, n_targets)) with pytest.raises( ValueError, match=( @@ -594,58 +458,19 @@ def test_spectral_decomposition_error_catch(method, mode): "channels in `indices`" ), ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - rank=(n_seeds, n_targets + 1), - ) + CoherencyDecomposition(**base_kwargs, rank=(n_seeds, n_targets + 1)) # Test n_jobs with pytest.raises(TypeError, match="`n_jobs` must be an instance of int"): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - n_jobs="1", - ) + CoherencyDecomposition(**base_kwargs, n_jobs="1") # Test verbose with pytest.raises( TypeError, match="`verbose` must be an instance of bool, str, int, or None" ): - CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - verbose=[True], - ) + CoherencyDecomposition(**base_kwargs, verbose=[True]) - decomp_class = CoherencyDecomposition( - info=epochs.info, - method=method, - indices=indices, - mode=mode, - fmin=fmin, - fmax=fmax, - cwt_freqs=cwt_freqs, - cwt_n_cycles=cwt_n_cycles, - ) + decomp_class = CoherencyDecomposition(**base_kwargs) # TEST BAD FITTING # Test input data format From 163d802ccd79eebfa0f8282ceb7ccf648067b158 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 13 Jun 2024 18:08:00 +0200 Subject: [PATCH 28/38] Add decomp class to main init --- examples/decoding/cohy_decomposition.py | 2 +- mne_connectivity/__init__.py | 1 + mne_connectivity/decoding/tests/test_decomposition.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 58dafb3e..751f145a 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -24,11 +24,11 @@ from mne.datasets.fieldtrip_cmc import data_path from mne_connectivity import ( + CoherencyDecomposition, make_signals_in_freq_bands, seed_target_indices, spectral_connectivity_epochs, ) -from mne_connectivity.decoding import CoherencyDecomposition ######################################################################################## # Background diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 92c8b7e0..ce18a284 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -24,6 +24,7 @@ TemporalConnectivity, ) from .datasets import make_signals_in_freq_bands +from .decoding import CoherencyDecomposition from .effective import phase_slope_index from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity diff --git a/mne_connectivity/decoding/tests/test_decomposition.py b/mne_connectivity/decoding/tests/test_decomposition.py index df9cce62..4a39119f 100644 --- a/mne_connectivity/decoding/tests/test_decomposition.py +++ b/mne_connectivity/decoding/tests/test_decomposition.py @@ -4,11 +4,11 @@ from numpy.testing import assert_allclose from 
mne_connectivity import ( + CoherencyDecomposition, make_signals_in_freq_bands, seed_target_indices, spectral_connectivity_epochs, ) -from mne_connectivity.decoding import CoherencyDecomposition from mne_connectivity.utils import _check_multivariate_indices From 865fa67fa2296f01eaba548ec3b66014735c740a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 13 Jun 2024 18:09:39 +0200 Subject: [PATCH 29/38] Update plotting docstrings --- doc/conf.py | 2 + mne_connectivity/decoding/decomposition.py | 239 +++++++++++++------ mne_connectivity/utils/docs.py | 253 +++++++++++++++++++-- 3 files changed, 406 insertions(+), 88 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index c2765602..5e94cd4f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -124,6 +124,8 @@ "n_estimated_nodes", "n_samples", "n_channels", + "n_patterns", + "n_filters", "Renderer", "n_ytimes", "n_ychannels", diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index d4e2dd69..40e8953b 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -474,7 +474,36 @@ def get_transformed_indices(self): ) @fill_doc - def plot_patterns(self, info, **kwargs): + def plot_patterns( + self, + info, + components=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%.1E", + units="AU", + axes=None, + name_format=None, + nrows=1, + ncols="auto", + show=True, + ): """Plot topographic patterns of components. The patterns explain how the measured data was generated from the @@ -484,42 +513,37 @@ def plot_patterns(self, info, **kwargs): Parameters ---------- - %(info_not_none)s - components : float | array of float | None - The patterns to plot. If ``None``, all components will be shown. - %(average_plot_evoked_topomap)s + %(info_decoding_plotting)s + %(components_topomap)s %(ch_type_topomap)s - scalings : dict | float | None - The scalings of the channel types to be applied for plotting. - If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``. + %(scalings_topomap)s %(sensors_topomap)s %(show_names_topomap)s %(mask_patterns_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s - %(sphere_topomap_auto)s + %(sphere_topomap)s %(image_interp_topomap)s %(extrapolate_topomap)s %(border_topomap)s %(res_topomap)s %(size_topomap)s %(cmap_topomap)s - %(vlim_plot_topomap)s - %(cnorm)s + %(vlim_topomap)s + %(cnorm_topomap)s %(colorbar_topomap)s - %(cbar_fmt_topomap)s + %(colorbar_format_topomap)s %(units_topomap)s - %(axes_evoked_plot_topomap)s - name_format : str | None - String format for topomap values. ``None`` defaults to f"{method}%%01d". - %(nrows_ncols_topomap)s + %(axes_topomap)s + %(name_format_topomap)s + %(nrows_topomap)s + %(ncols_topomap)s %(show)s Returns ------- - figs : list of instance of matplotlib.figure.Figure - The seed and target figures, respectively. 
+ %(figs_topomap)s """ if self.patterns_ is None: raise RuntimeError( @@ -527,11 +551,67 @@ def plot_patterns(self, info, **kwargs): ) return self._plot_filters_patterns( - (self.patterns_[0].T, self.patterns_[1].T), info, **kwargs + (self.patterns_[0].T, self.patterns_[1].T), + info, + components, + ch_type, + scalings, + sensors, + show_names, + mask, + mask_params, + contours, + outlines, + sphere, + image_interp, + extrapolate, + border, + res, + size, + cmap, + vlim, + cnorm, + colorbar, + cbar_fmt, + units, + axes, + name_format, + nrows, + ncols, + show, ) @fill_doc - def plot_filters(self, info, **kwargs): + def plot_filters( + self, + info, + components=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%.1E", + units="AU", + axes=None, + name_format=None, + nrows=1, + ncols="auto", + show=True, + ): """Plot topographic filters of components. The filters are used to extract discriminant neural sources from the measured @@ -541,87 +621,108 @@ def plot_filters(self, info, **kwargs): Parameters ---------- - %(info_not_none)s - components : float | array of float | None - The filters to plot. If ``None``, all components will be shown. - %(average_plot_evoked_topomap)s + %(info_decoding_plotting)s + %(components_topomap)s %(ch_type_topomap)s - scalings : dict | float | None - The scalings of the channel types to be applied for plotting. - If None, defaults to ``dict(eeg=1e6, grad=1e13, mag=1e15)``. + %(scalings_topomap)s %(sensors_topomap)s %(show_names_topomap)s - %(mask_patterns_topomap)s + %(mask_filters_topomap)s %(mask_params_topomap)s %(contours_topomap)s %(outlines_topomap)s - %(sphere_topomap_auto)s + %(sphere_topomap)s %(image_interp_topomap)s %(extrapolate_topomap)s %(border_topomap)s %(res_topomap)s %(size_topomap)s %(cmap_topomap)s - %(vlim_plot_topomap)s - %(cnorm)s + %(vlim_topomap)s + %(cnorm_topomap)s %(colorbar_topomap)s - %(cbar_fmt_topomap)s + %(colorbar_format_topomap)s %(units_topomap)s - %(axes_evoked_plot_topomap)s - name_format : str | None - String format for topomap values. ``None`` defaults to f"{method}%%01d". - %(nrows_ncols_topomap)s + %(axes_topomap)s + %(name_format_topomap)s + %(nrows_topomap)s + %(ncols_topomap)s %(show)s Returns ------- - figs : list of instance of matplotlib.figure.Figure - The seed and target figures, respectively. 
+ %(figs_topomap)s """ if self.filters_ is None: raise RuntimeError( "no filters are available, please call the `fit` method first" ) - return self._plot_filters_patterns(self.filters_, info, **kwargs) + return self._plot_filters_patterns( + self.filters_, + info, + components, + ch_type, + scalings, + sensors, + show_names, + mask, + mask_params, + contours, + outlines, + sphere, + image_interp, + extrapolate, + border, + res, + size, + cmap, + vlim, + cnorm, + colorbar, + cbar_fmt, + units, + axes, + name_format, + nrows, + ncols, + show, + ) def _plot_filters_patterns( self, plot_data, info, - components=None, - average=None, - ch_type=None, - scalings=None, - sensors=True, - show_names=False, - mask=None, - mask_params=None, - contours=6, - outlines="head", - sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, - border=_BORDER_DEFAULT, - res=64, - size=1, - cmap="RdBu_r", - vlim=(None, None), - cnorm=None, - colorbar=True, - cbar_fmt="%.1E", - units=None, - axes=None, - name_format=None, - nrows=1, - ncols="auto", - show=True, + components, + ch_type, + scalings, + sensors, + show_names, + mask, + mask_params, + contours, + outlines, + sphere, + image_interp, + extrapolate, + border, + res, + size, + cmap, + vlim, + cnorm, + colorbar, + cbar_fmt, + units, + axes, + name_format, + nrows, + ncols, + show, ): """Plot filters/targets for components.""" # Sort inputs _validate_type(info, Info, "`info`", "mne.Info") - if units is None: - units = "AU" if components is None: components = np.arange(self.n_components) @@ -639,7 +740,7 @@ def _plot_filters_patterns( figs.append( evoked.plot_topomap( times=components, - average=average, + average=None, # do not average across independent components ch_type=ch_type, scalings=scalings, sensors=sensors, diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index bc4d0461..74de1a42 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -8,7 +8,6 @@ except ImportError: from mne.externals.doccer import indentcount_lines as _indentcount_lines # noqa -from mne.utils.docs import docdict as mne_docdict ############################################################################## # Define our standard documentation entries @@ -60,40 +59,40 @@ """ docdict["mode"] = """ -mode : str (default "multitaper") - The cross-spectral density computation method. Can be ``"multitaper"``, - ``"fourier"``, or ``"cwt_morlet"``. +mode : str (default 'multitaper') + The cross-spectral density computation method. Can be ``'multitaper'``, + ``'fourier'``, or ``'cwt_morlet'``. """ docdict["mt_bandwidth"] = """ mt_bandwidth : int | float | None (default None) The bandwidth of the multitaper windowing function in Hz to use when computing the - cross-spectral density. Only used if ``mode="multitaper"``. + cross-spectral density. Only used if ``mode='multitaper'``. """ docdict["mt_adaptive"] = """ mt_adaptive : bool (default False) Whether to use adaptive weights when combining the tapered spectra in the - cross-spectral density. Only used if ``mode="multitaper"``. + cross-spectral density. Only used if ``mode='multitaper'``. """ docdict["mt_low_bias"] = """ mt_low_bias : bool (default True) Whether to use tapers with over 90 percent spectral concentration within the bandwidth when computing the cross-spectral density. Only used if - ``mode="multitaper"``. + ``mode='multitaper'``. 
""" docdict["cwt_freqs"] = """ cwt_freqs : array of int or float | None (default None) The frequencies of interest in Hz. Must not be ``None`` and only used if - ``mode="cwt_morlet"``. + ``mode='cwt_morlet'``. """ docdict["cwt_n_cycles"] = """ cwt_n_cycles : int | float | array of int or float (default 7) The number of cycles to use when constructing the Morlet wavelets. Fixed number or - one per frequency. Only used if ``mode=cwt_morlet``. + one per frequency. Only used if ``mode='cwt_morlet'``. """ docdict["coh"] = "'coh' : Coherence" @@ -185,7 +184,7 @@ ``None``. """ -# Decoding +# Decoding initialisation docdict["info_decoding"] = """ info : mne.Info Information about the data which will be decomposed and transformed, such as that @@ -197,20 +196,20 @@ method : str The multivariate method to use for the decomposition. Can be: - * ``"cacoh"`` - Canonical Coherency (CaCoh) :footcite:`VidaurreEtAl2019` - * ``"mic"`` - Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` + * ``'cacoh'`` - Canonical Coherency (CaCoh) :footcite:`VidaurreEtAl2019` + * ``'mic'`` - Maximised Imaginary part of Coherency (MIC) :footcite:`EwaldEtAl2012` """ docdict["fmin_decoding"] = """ fmin : int | float | None (default None) - The lowest frequency of interest in Hz. Must not be ``None`` and only used if - ``mode in ["multitaper", "fourier"]``. + The lowest frequency of interest in Hz. Must not be `None` and only used if + ``mode in ['multitaper', 'fourier']``. """ docdict["fmax_decoding"] = """ fmax : int | float | None (default None) - The highest frequency of interest in Hz. Must not be ``None`` and only used if - ``mode in ["multitaper", "fourier"]``. + The highest frequency of interest in Hz. Must not be `None` and only used if + ``mode in ['multitaper', 'fourier']``. """ docdict["indices_decoding"] = """ @@ -240,6 +239,7 @@ that of the data may reduce the degree of overfitting when computing the filters. """ +# Decoding attrs docdict["filters_"] = """ filters_ : tuple of array, shape=(n_signals, n_components) A tuple of two arrays containing the spatial filters for transforming the seed and @@ -252,9 +252,224 @@ filters for the seed and target data, respectively. """ -for key, val in mne_docdict.items(): - if key not in docdict: - docdict[key] = val +# Decoding plotting +docdict["info_decoding_plotting"] = """ +info : mne.Info + Information about the sensors of the data which has been decomposed, such as that + coming from an :class:`mne.Epochs` object. +""" + +# Topomaps +docdict["components_topomap"] = """ +components : int | array of int | None (default None) + The components to plot. If `None`, all components are shown. +""" + +docdict["ch_type_topomap"] = """ +ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | None (default None) + The channel type to plot. For ``'grad'``, the gradiometers are collected in pairs + and the RMS for each pair is plotted. If `None`, the first available channel type + from the order shown above is used. +""" + +docdict["scalings_topomap"] = """ +scalings : dict | float | None (default None) + The scalings of the channel types to be applied for plotting. If `None`, uses + ``dict(eeg=1e6, grad=1e13, mag=1e15)``. +""" + +docdict["sensors_topomap"] = """ +sensors : bool | str (default True) + Whether to add markers for sensor locations. If `str`, should be a valid + matplotlib format string (e.g., ``'r+'`` for red plusses; see the Notes section of + :meth:`~matplotlib.axes.Axes.plot`). If `True`, black circles are used. 
+""" + +docdict["show_names_topomap"] = """ +show_names : bool | callable (default False) + Whether to show channel names next to each sensor marker. If `callable`, channel + names will be formatted using the callable; e.g., to delete the prefix 'MEG ' from + all channel names, pass the function ``lambda x: x.replace('MEG ', '')``. If + ``mask`` is not `None`, only non-masked sensor names will be shown. +""" + +docdict["mask_filters_topomap"] = """ +mask : array of bool, shape=(n_channels, n_filters) | None (default None) + An array specifying channel-filter combinations to highlight with a distinct + plotting style. Array elements set to `True` will be plotted with the parameters + given in ``mask_params``. If `None`, no combinations will be highlighted. +""" +docdict["mask_patterns_topomap"] = """ +mask : array of bool, shape=(n_channels, n_patterns) | None (default None) + An array specifying channel-pattern combinations to highlight with a distinct + plotting style. Array elements set to `True` will be plotted with the parameters + given in ``mask_params``. If `None`, no combinations will be highlighted. +""" + +docdict["mask_params_topomap"] = """ +mask_params : dict | None (default None) + The plotting parameters for distinct combinations given in ``mask``. + Default `None` equals:: + + dict(marker='o', markerfacecolor='w', markeredgecolor='k', + linewidth=0, markersize=4) +""" + +docdict["contours_topomap"] = """ +contours : int | array (default 6) + The number of contour lines to draw. If ``0``, no contours will be drawn. If a + positive integer, that number of contour levels are chosen using the matplotlib tick + locator (may sometimes be inaccurate, use array for accuracy). If an array-like, the + values are used as the contour levels. The values should be in µV for EEG, fT for + magnetometers and fT/m for gradiometers. If ``colorbar=True``, the colorbar will + have ticks corresponding to the contour levels. +""" + +docdict["outlines_topomap"] = """ +outlines : 'head' | dict | None (default 'head') + The outlines to be drawn. If 'head', the default head scheme will be drawn. If dict, + each key refers to a tuple of x and y positions, the values in 'mask_pos' will serve + as image mask. Alternatively, a matplotlib patch object can be passed for advanced + masking options, either directly or as a function that returns patches (required for + multi-axis plots). If `None`, nothing will be drawn. +""" + +docdict["sphere_topomap"] = """ +sphere : float | array | mne.bem.ConductorModel | None | 'auto' | 'eeglab' (default None) + The sphere parameters to use for the head outline. Can be array-like of shape (4,) + to give the X/Y/Z origin and radius in meters, or a single float to give just the + radius (origin assumed 0, 0, 0). Can also be an instance of a spherical + :class:`~mne.bem.ConductorModel` to use the origin and radius from that object. If + ``'auto'`` the sphere is fit to digitization points. If ``'eeglab'`` the head circle + is defined by EEG electrodes ``'Fpz'``, ``'Oz'``, ``'T7'``, and ``'T8'`` (if + ``'Fpz'`` is not present, it will be approximated from the coordinates of ``'Oz'``). + `None` is equivalent to ``'auto'`` when enough extra digitization points are + available, and (0, 0, 0, 0.95) otherwise. +""" # noqa E501 + +docdict["image_interp_topomap"] = """ +image_interp : str (default 'cubic') + The image interpolation to be used. 
Options are ``'cubic'`` to use + :class:`scipy.interpolate.CloughTocher2DInterpolator`, ``'nearest'`` to use + :class:`scipy.spatial.Voronoi`, or ``'linear'`` to use + :class:`scipy.interpolate.LinearNDInterpolator`. +""" + +docdict["extrapolate_topomap"] = """ +extrapolate : str + The extrapolation options. Can be one of: + + - ``'box'`` + Extrapolate to four points placed to form a square encompassing all data points, + where each side of the square is three times the range of the data in the + respective dimension. + - ``'local'`` (default for MEG sensors) + Extrapolate only to nearby points (approximately to points closer than median + inter-electrode distance). This will also set the mask to be polygonal based on + the convex hull of the sensors. + - ``'head'`` (default for non-MEG sensors) + Extrapolate out to the edges of the clipping circle. This will be on the head + circle when the sensors are contained within the head circle, but it can extend + beyond the head when sensors are plotted outside the head circle. +""" + +docdict["border_topomap"] = """ +border : float | 'mean' (default 'mean') + The value to extrapolate to on the topomap borders. If ``'mean'``, each extrapolated + point has the average value of its neighbours. +""" + +docdict["res_topomap"] = """ +res : int (default 64) + The resolution of the topomap image (number of pixels along each side). +""" + +docdict["size_topomap"] = """ +size : int | float (default 1) + The side length of each subplot in inches. +""" + +docdict["cmap_topomap"] = """ +cmap : str | matplotlib.colors.Colormap | (matplotlib.colors.Colormap, bool) | 'interactive' | None (default 'RdBu_r') + The colormap to use. If a `str`, should be a valid matplotlib colormap. If a + `tuple`, the first value is `matplotlib.colors.Colormap` object to use and the + second value is a boolean defining interactivity. In interactive mode the colors are + adjustable by clicking and dragging the colorbar with left and right mouse button. + Left mouse button moves the scale up and down and right mouse button adjusts the + range. Hitting space bar resets the range. Up and down arrows can be used to change + the colormap. If `None`, ``'Reds'`` is used for data that is either all positive or + all negative, and ``'RdBu_r'`` is used otherwise. ``'interactive'`` is equivalent to + ``(None, True)``. + + .. warning:: Interactive mode works smoothly only for a small amount + of topomaps. Interactive mode is disabled by default for more than + 2 topomaps. +""" # noqa E501 + +docdict["vlim_topomap"] = """ +vlim : tuple of length 2 (default (None, None)) + The lower and upper colormap bounds, respectively. If both entries are `None`, sets + bounds to ``(min(data), max(data))``. If one entry is `None`, the corresponding + boundary is set at the min/max of the data. +""" + +docdict["cnorm_topomap"] = """ +cnorm : matplotlib.colors.Normalize | None (default None) + How to normalize the colormap. If `None`, standard linear normalization is used. If + not `None`, ``vlim`` is ignored. See the :ref:`Matplotlib docs + ` for more details on colormap normalization. +""" + +docdict["colorbar_topomap"] = """ +colorbar : bool (default True) + Whether to plot a colorbar in the rightmost column of the figure. +""" + +docdict["colorbar_format_topomap"] = r""" +cbar_fmt : str (default '%.1E') + The formatting string for colorbar tick labels. See :ref:`formatspec` for details. +""" + +docdict["units_topomap"] = """ +units : str (default 'AU') + The units for the colorbar label. 
Ignored if ``colorbar=False``. +""" + +docdict["axes_topomap"] = """ +axes : matplotlib.axes.Axes | list of matplotlib.axes.Axes | None (default None) + The axes to plot to. If `None`, a new figure will be created with the correct number + of axes. If not `None`, the number of axes must match ``components``. +""" + +docdict["name_format_topomap"] = r""" +name_format : str | None (default None) + The string format for axes titles. If `None`, uses f"{method}%%01d", i.e. the + method name followed by the component number. +""" + +docdict["nrows_topomap"] = """ +nrows : int | 'auto' (default 'auto') + The number of rows of components to plot. If ``'auto'``, the necessary number will + be inferred. +""" + +docdict["ncols_topomap"] = """ +ncols : int | 'auto' (default 'auto') + The number of columns of components to plot. If ``'auto'``, the necessary number + will be inferred. If ``nrows='auto'`` and ``ncols='auto'``, becomes ``nrows=1, + ncols='auto'``. +""" + +docdict["figs_topomap"] = """ +figs : list of matplotlib.figure.Figure + The seed and target figures, respectively. +""" + +docdict["show"] = """ +show : bool (default True) + Whether to show the figure. +""" + docdict_indented = dict() # type: ignore From ccdad63d6cd5fe73c3b667f26b68a4a9cca7ad10 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 13 Jun 2024 18:13:40 +0200 Subject: [PATCH 30/38] Add docs authorship --- mne_connectivity/utils/docs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index 74de1a42..0e8376fd 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -1,5 +1,6 @@ """The documentation functions.""" # Authors: Eric Larson +# Thomas S. Binns # # License: BSD (3-clause) From c026d3e2e40afd9e01cd8c520b6254407bc79aa9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 19 Jun 2024 16:58:43 +0200 Subject: [PATCH 31/38] Archive old example --- .../decoding/{cohy_decomposition.py => OLD_cohy_decomposition.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/decoding/{cohy_decomposition.py => OLD_cohy_decomposition.py} (100%) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/OLD_cohy_decomposition.py similarity index 100% rename from examples/decoding/cohy_decomposition.py rename to examples/decoding/OLD_cohy_decomposition.py From 03186a0b2e5f360b77ea22fc8e0eeb2aa5a70960 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 20 Jun 2024 12:19:24 +0200 Subject: [PATCH 32/38] Update decomp example formatting --- examples/decoding/cohy_decomposition.py | 26 ++++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py index 3d4ad72e..b916f6ec 100644 --- a/examples/decoding/cohy_decomposition.py +++ b/examples/decoding/cohy_decomposition.py @@ -3,10 +3,10 @@ Multivariate decomposition for efficient connectivity analysis ============================================================== -This example demonstrates how the tools in the decoding module can be used to -decompose data into the most relevant components of connectivity and used for -a computationally efficient multivariate analysis of connectivity, such as in -brain-computer interface (BCI) applications. 
+This example demonstrates how the tools in the decoding module can be used to decompose +data into the most relevant components of connectivity and used for a computationally +efficient multivariate analysis of connectivity, such as in brain-computer interface +(BCI) applications. """ # Author: Thomas S. Binns @@ -30,14 +30,13 @@ ######################################################################################## # Background # ---------- -# -# Multivariate forms of signal analysis allow you to simultaneously consider -# the activity of multiple signals. In the case of connectivity, the -# interaction between multiple sensors can be analysed at once and the strongest -# components of this interaction captured in a lower-dimensional set of connectivity -# spectra. This approach brings not only practical benefits (e.g. easier -# interpretability of results from the dimensionality reduction), but can also offer -# methodological improvements (e.g. enhanced signal-to-noise ratio and reduced bias). +# Multivariate forms of signal analysis allow you to simultaneously consider the +# activity of multiple signals. In the case of connectivity, the interaction between +# multiple sensors can be analysed at once and the strongest components of this +# interaction captured in a lower-dimensional set of connectivity spectra. This approach +# brings not only practical benefits (e.g. easier interpretability of results from the +# dimensionality reduction), but can also offer methodological improvements (e.g. +# enhanced signal-to-noise ratio and reduced bias). # # Coherency-based methods are popular approaches for analysing connectivity, capturing # correlations between signals in the frequency domain. Various coherency-based @@ -244,8 +243,7 @@ # connectivity is present. This problem can be mitigated by fitting filters to only # those frequencies where you expect connectivity to be present, e.g. as is done with # the decomposition class. - -######################################################################################## +# # In addition to assessing the validity of the approach, we can also look at the time # taken to run the analysis. 
Doing so, we see that the decomposition class is much # faster than the ``spectral_connectivity_...()`` functions, thanks to the fact that the From 1314cb478b71a5105aa0d720f41007dabfb14835 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 20 Jun 2024 12:19:57 +0200 Subject: [PATCH 33/38] Add decomp plotting example Remove old decomp example Add reference to plotting example --- examples/decoding/OLD_cohy_decomposition.py | 613 ------------------ examples/decoding/cohy_decomposition.py | 14 + .../decoding/cohy_decomposition_plotting.py | 170 +++++ 3 files changed, 184 insertions(+), 613 deletions(-) delete mode 100644 examples/decoding/OLD_cohy_decomposition.py create mode 100644 examples/decoding/cohy_decomposition_plotting.py diff --git a/examples/decoding/OLD_cohy_decomposition.py b/examples/decoding/OLD_cohy_decomposition.py deleted file mode 100644 index 751f145a..00000000 --- a/examples/decoding/OLD_cohy_decomposition.py +++ /dev/null @@ -1,613 +0,0 @@ -""" -============================================================== -Multivariate decomposition for efficient connectivity analysis -============================================================== - -This example demonstrates how the tools in the decoding module can be used to -decompose data into the most relevant components of connectivity and used for -a computationally efficient multivariate analysis of connectivity, such as in -brain-computer interface (BCI) applications. -""" - -# Author: Thomas S. Binns -# License: BSD (3-clause) -# sphinx_gallery_thumbnail_number = 2 - -# %% - -import time - -import mne -import numpy as np -from matplotlib import pyplot as plt -from mne import make_fixed_length_epochs -from mne.datasets.fieldtrip_cmc import data_path - -from mne_connectivity import ( - CoherencyDecomposition, - make_signals_in_freq_bands, - seed_target_indices, - spectral_connectivity_epochs, -) - -######################################################################################## -# Background -# ---------- -# -# Multivariate forms of signal analysis allow you to simultaneously consider -# the activity of multiple signals. In the case of connectivity, the -# interaction between multiple sensors can be analysed at once and the strongest -# components of this interaction captured in a lower-dimensional set of connectivity -# spectra. This approach brings not only practical benefits (e.g. easier -# interpretability of results from the dimensionality reduction), but can also offer -# methodological improvements (e.g. enhanced signal-to-noise ratio and reduced bias). -# -# Coherency-based methods are popular approaches for analysing connectivity, capturing -# correlation between signals in the frequency domain. Various coherency-based -# multivariate methods exist, including: canonical coherency (CaCoh; multivariate -# measure of coherency/coherence) :footcite:`VidaurreEtAl2019` ; and maximised imaginary -# coherency (MIC; multivariate measure of the imaginary part of coherency) -# :footcite:`EwaldEtAl2012`. -# -# These methods are described in detail in the following examples: -# -# - comparison of coherency-based methods - :doc:`../compare_coherency_methods` -# - CaCoh - :doc:`../cacoh` -# - MIC - :doc:`../mic_mim` -# -# The CaCoh and MIC methods work by finding spatial filters that decompose the data into -# components of connectivity, and applying them to the data. 
With the implementations -# offered in :func:`~mne_connectivity.spectral_connectivity_epochs` and -# :func:`~mne_connectivity.spectral_connectivity_time`, the filters are fit for each -# frequency separately, and the filters are only applied to the same data they are fit -# on. -# -# Unfortunately, fitting filters for each frequency bin can be computationally -# expensive, which may prohibit the use of these techniques, e.g. in real-time BCI -# setups where the rapid analysis of data is paramount, or even in offline analyses -# with huge datasets. -# -# These issues are addressed by the -# :class:`~mne_connectivity.decoding.CoherencyDecomposition` class of the decoding -# module. Here, the filters are fit for a given frequency band collectively (not each -# frequency bin!) and are stored, allowing them to be applied to the same data they were -# fit on (e.g. for offline analyses of huge datasets) or to new data (e.g. for online -# analyses of streamed data). -# -# In this example, we show how the tools of the decoding module compare to the standard -# ``spectral_connectivity_...()`` functions in terms of their run time, and their -# ability to decompose data into connectivity components. - -######################################################################################## -# Case 1: Fitting to and transforming different data -# -------------------------------------------------- -# -# We start by simulating some connectivity between two groups of signals at 15-20 Hz as -# 60 two-second-long epochs. To demonstrate the approach of fitting filters to one set -# of data and applying to another set of data, we will treat the first 30 epochs as the -# data on which we train the filters, and the last 30 epochs as the data we transform. -# We will use the CaCoh method, since zero time-lag interactions are not present (See -# :doc:`../compare_coherency_methods` for more information). - -# %% - -N_SEEDS = 10 -N_TARGETS = 15 - -FMIN = 15 -FMAX = 20 - -N_EPOCHS = 60 - -epochs = make_signals_in_freq_bands( - n_seeds=N_SEEDS, - n_targets=N_TARGETS, - freq_band=(FMIN, FMAX), - n_epochs=N_EPOCHS, - n_times=200, - sfreq=100, - snr=0.2, - rng_seed=44, -) - -indices = (np.arange(N_SEEDS), np.arange(N_TARGETS) + N_SEEDS) - -######################################################################################## -# First, we use the standard CaCoh approach in -# :func:`~mne_connectivity.spectral_connectivity_epochs` to visualise the connectivity -# in the first 30 epochs. We also plot bivariate coherence to demonstrate the -# signal-to-noise enhancements this multivariate approach offers. As expected, we see a -# peak in connectivity at 15-20 Hz decomposed by the spatial filters. - -# %% - -# Connectivity profile of first 30 epochs (filters fit to these epochs) -con_cacoh_first = spectral_connectivity_epochs( - epochs[: N_EPOCHS // 2], - method="cacoh", - indices=([indices[0]], [indices[1]]), - fmin=5, - fmax=35, - rank=([3], [3]), -) -ax = plt.subplot(111) -ax.plot(con_cacoh_first.freqs, np.abs(con_cacoh_first.get_data()[0]), label="CaCoh") - -# Connectivity profile of first 30 epochs (no filters) -con_coh_first = spectral_connectivity_epochs( - epochs[: N_EPOCHS // 2], - method="coh", - indices=seed_target_indices(indices[0], indices[1]), - fmin=5, - fmax=35, -) -ax.plot(con_coh_first.freqs, np.mean(con_coh_first.get_data(), axis=0), label="Coh") -ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. 
band") -ax.set_xlabel("Frequency (Hz)") -ax.set_ylabel("Connectivity (A.U.)") -ax.set_title("Epochs 0-30") -plt.legend() -plt.show() - -######################################################################################## -# The goal of the decoding module approach is to use the information from the first 30 -# epochs to fit the filters, and then use these filters to extract the same components -# from the last 30 epochs. -# -# For this, we instantiate the -# :class:`~mne_connectivity.decoding.CoherencyDecomposition` class with: the -# information about the data being fit/transformed (using an :class:`~mne.Info` object); -# the type of connectivity we want to decompose (here CaCoh); the frequency band of the -# components we want to decompose (here 15-20 Hz); and the channel indices of the seeds -# and targets. -# -# Next, we call the :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit` -# method, passing in the first 30 epochs of data we want to fit the filters to. Once the -# filters are fit, we can apply them to the last 30 epochs using the -# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.transform` method. -# -# The transformed data has shape ``(epochs x components*2 x times)``, where the new -# 'channels' are organised as the seed components, then target components. For -# convenience, the -# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.get_transformed_indices` -# method can be used to get the ``indices`` of the transformed data for use in the -# ``spectral_connectivity_...()`` functions. - -# %% - -# Fit filters to first 30 epochs -cacoh = CoherencyDecomposition( - info=epochs.info, - method="cacoh", - indices=indices, - mode="multitaper", - fmin=FMIN, - fmax=FMAX, - rank=(3, 3), -) -cacoh.fit(epochs[: N_EPOCHS // 2].get_data()) - -# Use filters to transform data from last 30 epochs -epochs_transformed = cacoh.transform(epochs[N_EPOCHS // 2 :].get_data()) -indices_transformed = cacoh.get_transformed_indices() - -######################################################################################## -# We can now visualise the connectivity in the last 30 epochs of the transformed data, -# which for reference we will compare to connectivity in the last 30 epochs using -# filters fit to the data itself, as well as bivariate coherence to again demonstrate -# the signal-to-noise enhancements the multivariate approach offers. -# -# To compute connectivity of the transformed data, it is simply a case of passing to the -# ``spectral_connectivity_...()`` functions: the transformed data; the indices -# returned from -# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.get_transformed_indices`; and -# the corresponding bivariate method (``"coh"`` and ``"cohy"`` for CaCoh; ``"imcoh"`` -# for MIC). -# -# As you can see, the connectivity profile of the transformed data using filters fit on -# the first 30 epochs is very similar to the connectivity profile when using filters fit -# on the last 30 epochs. This shows that the filters are generalisable, able to extract -# the same components of connectivity which they were trained on from new data. 
- -# %% - -# Connectivity profile of last 30 epochs (filters fit to these epochs) -con_cacoh_last = spectral_connectivity_epochs( - epochs[N_EPOCHS // 2 :], - method="cacoh", - indices=([indices[0]], [indices[1]]), - fmin=5, - fmax=35, - rank=([3], [3]), -) -ax = plt.subplot(111) -ax.plot( - con_cacoh_last.freqs, - np.abs(con_cacoh_last.get_data()[0]), - label="CaCoh (filters trained\non epochs 30-60)", -) - -# Connectivity profile of last 30 epochs (no filters) -con_coh_last = spectral_connectivity_epochs( - epochs[N_EPOCHS // 2 :], - method="coh", - indices=seed_target_indices(indices[0], indices[1]), - fmin=5, - fmax=35, -) -ax.plot( - con_coh_last.freqs, np.mean(np.abs(con_coh_last.get_data()), axis=0), label="Coh" -) - -# Connectivity profile of last 30 epochs (filters fit to first 30 epochs) -con_cacoh_last_from_first = spectral_connectivity_epochs( - epochs_transformed, - method="coh", - indices=indices_transformed, - fmin=5, - fmax=35, - sfreq=epochs.info["sfreq"], -) -ax.plot( - con_cacoh_last_from_first.freqs, - np.abs(con_cacoh_last_from_first.get_data()[0]), - label="CaCoh (filters trained\non epochs 0-30)", -) -ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band") -ax.set_xlabel("Frequency (Hz)") -ax.set_ylabel("Connectivity (A.U.)") -ax.set_title("Epochs 30-60") -plt.legend() -plt.show() - -######################################################################################## -# In addition to assessing the validity of the approach, we can also look at the time -# taken to run the analysis. Below we present a scenario resembling an online sliding -# window approach typical of a BCI system. We consider the first 30 epochs to be the -# training data that the filters should be fit to, and the last 30 epochs to be the -# windows of data that the filters should be applied to, transforming and computing the -# connectivity of each window (epoch) of data sequentially. -# -# Doing so, we see that once the filters have been fit, it takes only a few milliseconds -# to transform each window of data and compute its connectivity. - -# %% - -cacoh = CoherencyDecomposition( - info=epochs.info, - method="cacoh", - indices=indices, - mode="multitaper", - fmin=FMIN, - fmax=FMAX, - rank=(3, 3), -) - -# Time fitting of filters -start_fit = time.time() -cacoh.fit(epochs[: N_EPOCHS // 2].get_data()) -fit_duration = (time.time() - start_fit) * 1000 - -# Time transforming data of each epoch iteratively -start_transform = time.time() -for epoch in epochs[N_EPOCHS // 2 :]: - epoch_transformed = cacoh.transform(epoch) - spectral_connectivity_epochs( - np.expand_dims(epoch_transformed, axis=0), - method="coh", - indices=indices_transformed, - fmin=5, - fmax=35, - sfreq=epochs.info["sfreq"], - ) -transform_duration = (time.time() - start_transform) * 1000 - -# %% - -print(f"Time to fit filters: {fit_duration:.0f} ms") -print(f"Time to transform data and compute connectivity: {transform_duration:.0f} ms") -print(f"Total time: {fit_duration + transform_duration:.0f} ms") - -print( - "\nTime to transform data and compute connectivity per epoch (window): ", - f"{transform_duration/(N_EPOCHS//2):.0f} ms", -) - -######################################################################################## -# In contrast, here we follow the same sequential window approach, but fit filters to -# each window separately rather than using a pre-computed set. Naturally, the process of -# fitting and transforming the data for each window is considerably slower. 
-# -# Furthermore, given the noisy nature of single windows of data, there is a risk of -# overfitting the filters to this noise as opposed to the genuine interaction(s) of -# interest. This risk is mitigated by performing the initial filter fitting on a larger -# set of data. - -# %% - -# Time fitting and transforming data of each epoch iteratively -start_fit_transform = time.time() -for epoch in epochs[N_EPOCHS // 2 :]: - spectral_connectivity_epochs( - np.expand_dims(epoch, axis=0), - method="cacoh", - indices=([indices[0]], [indices[1]]), - fmin=5, - fmax=35, - sfreq=epochs.info["sfreq"], - rank=([3], [3]), - ) -fit_transform_duration = (time.time() - start_fit_transform) * 1000 - -# %% - -print( - f"Time to fit, transform, and compute connectivity: {fit_transform_duration:.0f} ms" -) - -print( - "\nTime to fit, transform, and compute connectivity per epoch (window): ", - f"{fit_transform_duration/(N_EPOCHS//2):.0f} ms", -) - -######################################################################################## -# As a side note, it is important to consider that a multivariate approach may be as -# fast or even faster than a bivariate approach, depending on the number of connections -# and degree of rank subspace projection being performed. - -# %% - -# Time transforming data of each epoch iteratively -start = time.time() -for epoch in epochs[N_EPOCHS // 2 :]: - spectral_connectivity_epochs( - np.expand_dims(epoch, axis=0), - method="coh", - indices=seed_target_indices(indices[0], indices[1]), - fmin=5, - fmax=35, - sfreq=epochs.info["sfreq"], - ) -duration = (time.time() - start) * 1000 - -# %% - -print(f"Time to compute connectivity: {duration:.0f} ms") - -print( - "\nTime to compute connectivity per epoch (window): ", - f"{duration/(N_EPOCHS//2):.0f} ms", -) - -######################################################################################## -# Case 2: Fitting to and transforming the same data -# ------------------------------------------------- -# -# As mentioned above, the decoding module classes can also be used to transform the same -# data the filters are fit to. This is a similar process to that of the -# ``spectral_connectivity_...()`` functions, but with the increased efficiency of -# fitting filters to a single frequency band as opposed to each frequency bin. -# -# To demonstrate this approach, we will load some example MEG data and divide it into -# two-second-long epochs. We designate the left hemisphere sensors as the seeds and the -# right hemisphere sensors as the targets. Since this is sensor-space data, we will use -# the MIC method to analyse connectivity given its resilience to zero time-lag -# interactions (See :doc:`../compare_coherency_methods` for more information). 
-
-# %%
-
-raw = mne.io.read_raw_ctf(data_path() / "SubjectCMC.ds")
-raw.pick("mag")
-raw.crop(50.0, 110.0).load_data()
-raw.notch_filter(50)
-raw.resample(100)
-
-epochs = make_fixed_length_epochs(raw, duration=2.0).load_data()
-
-# left hemisphere sensors
-seeds = [idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] < 0]
-# right hemisphere sensors
-targets = [
-    idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] > 0
-]
-
-########################################################################################
-# There are two equivalent options for fitting and transforming the same data: 1)
-# passing the data to the :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit`
-# and :meth:`~mne_connectivity.decoding.CoherencyDecomposition.transform` methods
-# sequentially; or 2) using the combined
-# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit_transform` method.
-#
-# We use the latter approach below, fitting the filters to the 15-20 Hz band and using
-# the ``"imcoh"`` method in the call to the ``spectral_connectivity_...()`` functions.
-# Plotting the results, we see a peak in connectivity at 15-20 Hz.
-
-# %%
-
-mic = CoherencyDecomposition(
-    info=epochs.info,
-    method="mic",
-    indices=(seeds, targets),
-    mode="multitaper",
-    fmin=FMIN,
-    fmax=FMAX,
-    rank=(3, 3),
-)
-
-start = time.time()
-epochs_transformed = mic.fit_transform(epochs.get_data())
-
-con_mic_class = spectral_connectivity_epochs(
-    epochs_transformed,
-    method="imcoh",
-    indices=mic.get_transformed_indices(),
-    fmin=5,
-    fmax=30,
-    sfreq=epochs.info["sfreq"],
-)
-class_duration = time.time() - start
-
-ax = plt.subplot(111)
-ax.plot(
-    con_mic_class.freqs,
-    np.abs(con_mic_class.get_data()[0]),
-    color=plt.rcParams["axes.prop_cycle"].by_key()["color"][2],
-    label="MIC (decomposition\nclass)",
-)
-ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band")
-ax.set_xlabel("Frequency (Hz)")
-ax.set_ylabel("Connectivity (A.U.)")
-plt.legend()
-plt.show()
-
-########################################################################################
-# For comparison, we can also use the standard approach of the
-# ``spectral_connectivity_...()`` functions, which shows a very similar connectivity
-# profile in the 15-20 Hz frequency range (but not identical due to band- vs. bin-wise
-# filter fitting approaches). Bivariate coherence is again shown to demonstrate the
-# signal-to-noise enhancements the multivariate approach offers.
-
-# %%
-
-start = time.time()
-con_mic_func = spectral_connectivity_epochs(
-    epochs,
-    method="mic",
-    indices=([seeds], [targets]),
-    fmin=5,
-    fmax=30,
-    rank=([3], [3]),
-)
-func_duration = time.time() - start
-
-con_imcoh = spectral_connectivity_epochs(
-    epochs,
-    method="imcoh",
-    indices=seed_target_indices(seeds, targets),
-    fmin=5,
-    fmax=30,
-    rank=([3], [3]),
-)
-
-ax = plt.subplot(111)
-ax.plot(
-    con_mic_func.freqs,
-    np.abs(con_mic_func.get_data()[0]),
-    label="MIC (standard\nfunction)",
-)
-ax.plot(
-    con_imcoh.freqs,
-    np.mean(np.abs(con_imcoh.get_data()), axis=0),
-    label="ImCoh",
-)
-ax.plot(
-    con_mic_class.freqs,
-    np.abs(con_mic_class.get_data()[0]),
-    label="MIC (decomposition\nclass)",
-)
-ax.axvspan(FMIN, FMAX, color="grey", alpha=0.2, label="Fitted freq. band")
-ax.set_xlabel("Frequency (Hz)")
-ax.set_ylabel("Connectivity (A.U.)")
-plt.legend()
-plt.show()
-
-########################################################################################
-# As with the previous example, we can also compare the time taken to run the analyses.
-# Here we see that the decomposition class is much faster than the
-# ``spectral_connectivity_...()`` functions, thanks to the fact that the filters are fit
-# to the entire frequency band and not each frequency bin.
-
-# %%
-
-print(
-    "Time to fit, transform, and compute connectivity (decomposition class): "
-    f"{class_duration:.2f} s"
-)
-print(
-    f"Time to fit, transform, and compute connectivity (standard function): "
-    f"{func_duration:.2f} s"
-)
-
-########################################################################################
-# Visualising filters and patterns
-# --------------------------------
-# In addition to the connectivity scores, useful insights about the data can be gained
-# by visualising the topographies of the filters and patterns, which represent two
-# complementary aspects:
-#
-# - The filters represent how the connectivity sources are extracted from the channel
-#   data, akin to an inverse model.
-# - The patterns represent how the channel data is formed by the connectivity sources,
-#   akin to a forward model.
-#
-# This distinction is discussed further in Haufe *et al.* (2014)
-# :footcite:`HaufeEtAl2014`, but in short: **the patterns should be used to interpret
-# the contribution of distinct brain regions/sensors to a given component of
-# connectivity**. Accordingly, keep in mind that the filters and patterns are not a
-# replacement for source reconstruction, as without this the patterns will still only
-# tell you about the spatial contributions of sensors, not underlying brain regions,
-# to connectivity.
-#
-# Visualising these topographies can be done using the
-# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_filters` and
-# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_patterns` methods.
-#
-# When interpreting patterns, note that the absolute value reflects the strength of the
-# contribution to connectivity, however the sign differences can be used to visualise
-# the orientation of the underlying dipole sources. The spatial patterns are **not**
-# bound between :math:`[-1, 1]`.
-#
-# Plotting the patterns below, we can infer the existence of postcentral, generally
-# medial dipole sources contributing to the connectivity between sensors over left and
-# right hemispheres at 15-20 Hz.
-
-# %%
-
-# Plot patterns
-mic.plot_patterns(epochs.info, sensors="m.", size=2)
-
-########################################################################################
-# For comparison we can also plot the filters, and here we see that they show a very
-# similar topography to the patterns. However, this is not always the case, and you
-# should never confuse the information represented by the filters and patterns, which
-# can lead to very incorrect interpretations of the data :footcite:`HaufeEtAl2014`.
-
-# %%
-
-# Plot filters
-mic.plot_filters(epochs.info, sensors="m.", size=2)
-
-########################################################################################
-# Limitations
-# -----------
-# Finally, it is important to discuss a key limitation of the decoding module approach:
-# the need to define a specific frequency band. Defining this band requires some
-# existing knowledge about your data or the oscillatory activity you are studying. This
-# insight may come from a pilot study where a frequency band of interest was identified,
-# a canonical frequency band defined in the literature, etc... In contrast, by fitting
-# filters to each frequency bin, the standard ``spectral_connectivity_...()`` functions
-# are more flexible.
-#
-# Additionally, by applying filters fit on one set of data to another, you are assuming
-# that the connectivity components the filters are designed to extract are consistent
-# across the two sets of data. However, this may not be the case if you are applying the
-# filters to data from a distinct functional state where the spatial distribution of the
-# components differs. Again, by fitting filters to each new set of data passed in, the
-# standard ``spectral_connectivity_...()`` functions are more flexible, extracting
-# whatever connectivity components are present in that data.
-#
-# On these points, we note that the ``spectral_connectivity_...()`` functions complement
-# the decoding module classes well, offering a tool by which to explore your data to:
-# identify possible frequency bands of interest; and identify the spatial distributions
-# of connectivity components to determine if they are consistent across different
-# portions of the data.
-#
-# Ultimately, there are distinct advantages and disadvantages to both approaches, and
-# one may be more suitable than the other depending on your use case.
-
-########################################################################################
-# References
-# ----------
-# .. footbibliography::
-
-# %%
diff --git a/examples/decoding/cohy_decomposition.py b/examples/decoding/cohy_decomposition.py
index b916f6ec..26a2e18c 100644
--- a/examples/decoding/cohy_decomposition.py
+++ b/examples/decoding/cohy_decomposition.py
@@ -548,6 +548,20 @@
 # Ultimately, there are distinct advantages and disadvantages to both approaches, and
 # one may be more suitable than the other depending on your use case.
 
+########################################################################################
+# Visualising spatial contributions to connectivity
+# -------------------------------------------------
+# In addition to the lower-dimensional representation of connectivity, we can also
+# extract information about the spatial distributions of connectivity over channels.
+# This information is captured in the spatial patterns, derived from the spatial
+# filters :footcite:`HaufeEtAl2014`.
+#
+# The patterns (and filters) can be visualised as topomaps using the
+# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_patterns` and
+# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_filters` methods of the
+# :class:`~mne_connectivity.decoding.CoherencyDecomposition` class, discussed in more
+# detail in :doc:`cohy_decomposition_plotting`.
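+#
+# As a quick, hedged illustration of those methods (the ``mic`` and ``epochs`` names
+# below are placeholders for a fitted
+# :class:`~mne_connectivity.decoding.CoherencyDecomposition` object and the epoched
+# data it was fit on, as created earlier in this example; a sketch, not prescribed
+# usage):
+#
+# .. code-block:: python
+#
+#     mic.plot_patterns(info=epochs.info)  # patterns: forward-model view (interpret)
+#     mic.plot_filters(info=epochs.info)  # filters: inverse-model view (extraction)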
+
 
 ########################################################################################
 # References
 # ----------
diff --git a/examples/decoding/cohy_decomposition_plotting.py b/examples/decoding/cohy_decomposition_plotting.py
new file mode 100644
index 00000000..d7768963
--- /dev/null
+++ b/examples/decoding/cohy_decomposition_plotting.py
@@ -0,0 +1,170 @@
+"""
+==============================================================
+Visualising spatial contributions to multivariate connectivity
+==============================================================
+
+This example demonstrates how the spatial filters and patterns of connectivity obtained
+from the decomposition tools in the decoding module can be visualised and interpreted.
+"""
+
+# Author: Thomas S. Binns
+# License: BSD (3-clause)
+
+# %%
+
+import mne
+from mne import make_fixed_length_epochs
+from mne.datasets.fieldtrip_cmc import data_path
+
+from mne_connectivity import CoherencyDecomposition
+
+########################################################################################
+# Background
+# ----------
+# Multivariate forms of signal analysis allow you to simultaneously consider the
+# activity of multiple signals. In the case of connectivity, the interaction between
+# multiple sensors can be analysed at once and the strongest components of this
+# interaction captured in a lower-dimensional set of connectivity spectra. This approach
+# not only brings practical benefits (e.g. easier interpretability of results from the
+# dimensionality reduction), but can also offer methodological improvements (e.g.
+# enhanced signal-to-noise ratio and reduced bias).
+#
+# Coherency-based methods are popular approaches for analysing connectivity, capturing
+# correlations between signals in the frequency domain. Various coherency-based
+# multivariate methods exist, including: canonical coherency (CaCoh; multivariate
+# measure of coherency/coherence) :footcite:`VidaurreEtAl2019`; and maximised imaginary
+# coherency (MIC; multivariate measure of the imaginary part of coherency)
+# :footcite:`EwaldEtAl2012`.
+#
+# These methods are described in detail in the following examples:
+# - comparison of coherency-based methods - :doc:`../compare_coherency_methods`
+# - CaCoh - :doc:`../cacoh`
+# - MIC - :doc:`../mic_mim`
+#
+# The CaCoh and MIC methods work by finding spatial filters that decompose the data into
+# components of connectivity, and applying them to the data. Connectivity can then be
+# computed on this transformed data (see :doc:`cohy_decomposition` for more
+# information).
+#
+# However, in addition to the connectivity scores, useful insights about the data can be
+# gained by visualising the topographies of the spatial filters and their corresponding
+# spatial patterns. These provide important information about the spatial distributions
+# of the connectivity components, and represent two complementary aspects:
+#
+# - The filters represent how the connectivity sources are extracted from the channel
+#   data, akin to an inverse model.
+# - The patterns represent how the channel data is formed by the connectivity sources,
+#   akin to a forward model.
+#
+# This distinction is discussed further in Haufe *et al.* (2014)
+# :footcite:`HaufeEtAl2014`, but in short: **the patterns should be used to interpret
+# the contribution of distinct brain regions/sensors to a given component of
+# connectivity**. Accordingly, keep in mind that the filters and patterns are not a
+# replacement for source reconstruction: without it, the patterns still only tell you
+# about the spatial contributions of sensors, not of the underlying brain regions, to
+# connectivity.
+
+########################################################################################
+# Generating the filters and patterns
+# -----------------------------------
+# We will first load some example MEG data which we will generate the spatial filters
+# and patterns for, and divide it into epochs.
+
+# %%
+
+# Load example MEG data
+raw = mne.io.read_raw_ctf(data_path() / "SubjectCMC.ds")
+raw.pick("mag")
+raw.crop(50.0, 110.0).load_data()
+raw.notch_filter(50)
+raw.resample(100)
+
+# Create epochs
+epochs = make_fixed_length_epochs(raw, duration=2.0).load_data()
+
+########################################################################################
+# We designate the left hemisphere sensors as the seeds and the right hemisphere sensors
+# as the targets. Since this is sensor-space data, we will use the MIC method to analyse
+# connectivity, given its resilience to zero time-lag interactions (see
+# :doc:`../compare_coherency_methods` for more information).
+
+# %%
+
+# Left hemisphere sensors
+seeds = [idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] < 0]
+
+# Right hemisphere sensors
+targets = [
+    idx for idx, ch_info in enumerate(epochs.info["chs"]) if ch_info["loc"][0] > 0
+]
+
+# Define indices
+indices = (seeds, targets)
+
+########################################################################################
+# To fit the filters (and in turn compute the corresponding patterns), we instantiate
+# the :class:`~mne_connectivity.decoding.CoherencyDecomposition` object and call the
+# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.fit` method. We also define
+# our connectivity frequency band of interest to be 20-30 Hz. See
+# :doc:`cohy_decomposition` for more information.
+
+# %%
+
+# Instantiate decomposition object
+mic = CoherencyDecomposition(
+    info=epochs.info,
+    method="mic",
+    indices=indices,
+    mode="multitaper",
+    fmin=20,
+    fmax=30,
+    rank=(3, 3),
+)
+
+# Fit filters & generate patterns
+mic.fit(epochs.get_data())
+
+########################################################################################
+# Visualising the patterns
+# ------------------------
+# Visualising the patterns as topomaps can be done using the
+# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_patterns` method.
+#
+# When interpreting patterns, note that the absolute value reflects the strength of the
+# contribution to connectivity, and that the sign differences can be used to visualise
+# the orientation of the underlying dipole sources. The spatial patterns are **not**
+# bound between :math:`[-1, 1]`.
+#
+# Plotting the patterns for 20-30 Hz connectivity below, we find the strongest
+# connectivity between the left and right hemispheres comes from centromedial left and
+# frontolateral right sensors, based on the areas with the largest absolute values. As
+# these patterns come from decomposition on sensor-space data, we make no assumptions
+# about the underlying brain regions involved in this connectivity.
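+#
+# As a hedged aside before plotting: the axes titles of these topomaps can be
+# customised via the ``name_format`` argument of the plotting methods (by default, the
+# method name followed by the component number). The format string below is
+# illustrative only:
+#
+# .. code-block:: python
+#
+#     # e.g. title each component's topomap as "MIC comp. <number>"
+#     mic.plot_patterns(info=epochs.info, name_format="MIC comp. %01d")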
+ +# %% + +# Plot patterns +mic.plot_patterns(info=epochs.info, sensors="m.", size=2) + +######################################################################################## +# Visualising the filters +# ----------------------- +# We can also visualise the filters as topomaps using the +# :meth:`~mne_connectivity.decoding.CoherencyDecomposition.plot_filters` method. +# +# Here we see that the filters show a similar topography to the patterns. However, this +# is not always the case, and you should never confuse the information represented by +# the filters (i.e. an inverse model) and patterns (i.e. a forward model), which can +# lead to very incorrect interpretations of the data :footcite:`HaufeEtAl2014`. + +# %% + +# Plot filters +mic.plot_filters(info=epochs.info, sensors="m.", size=2) + +######################################################################################## +# References +# ---------- +# .. footbibliography:: + +# %% From 0a3f7ace1cf7df91ac871da5cfd4a3232aa76e57 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Thu, 20 Jun 2024 12:20:59 +0200 Subject: [PATCH 34/38] Update epochs_multivar formatting and docs --- mne_connectivity/spectral/epochs_multivariate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 3881e580..b445bc73 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -115,8 +115,8 @@ def __init__( self.n_cons = n_cons self.n_freqs = n_freqs self.n_times = n_times - self.n_jobs = n_jobs self.store_filters = store_filters + self.n_jobs = n_jobs # include time dimension, even when unused for indexing flexibility if n_times == 0: @@ -193,7 +193,7 @@ class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase): """Base estimator for multivariate coherency methods. See: - - Imaginary part of coherency, i.e. multivariate imaginary part of + - Imaginary part of coherency, i.e. maximised imaginary part of coherency (MIC) and multivariate interaction measure (MIM): Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084 - Coherency/coherence, i.e. canonical coherency (CaCoh): Vidaurre et al. From ad5db634415e6b0da2afe39e28cbe0103b9e0b01 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 26 Jun 2024 18:50:42 +0200 Subject: [PATCH 35/38] Update docstring --- mne_connectivity/utils/docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/utils/docs.py b/mne_connectivity/utils/docs.py index ed101b1c..c90100a4 100644 --- a/mne_connectivity/utils/docs.py +++ b/mne_connectivity/utils/docs.py @@ -444,7 +444,7 @@ docdict["name_format_topomap"] = r""" name_format : str | None (default None) - The string format for axes titles. If `None`, uses f"{method}%%01d", i.e. the + The string format for axes titles. If `None`, uses ``f"{method}%%01d"``, i.e. the method name followed by the component number. 
""" From ee1c109df2449524b9ea9e6feb8471251cd63d8d Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 26 Jun 2024 18:52:02 +0200 Subject: [PATCH 36/38] Update plotting --- mne_connectivity/decoding/decomposition.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/decoding/decomposition.py b/mne_connectivity/decoding/decomposition.py index 75097347..1263cd86 100644 --- a/mne_connectivity/decoding/decomposition.py +++ b/mne_connectivity/decoding/decomposition.py @@ -4,7 +4,6 @@ # # License: BSD (3-clause) -from copy import deepcopy from typing import Optional import numpy as np @@ -16,6 +15,7 @@ from mne.fixes import BaseEstimator from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper from mne.utils import _check_option, _validate_type +from mne.viz.utils import plt_show from ..spectral.epochs_multivariate import _CaCohEst, _check_rank_input, _MICEst from ..utils import _check_multivariate_indices, fill_doc @@ -727,8 +727,7 @@ def _plot_filters_patterns( figs = [] for group_idx, group_name in zip([0, 1], ["Seeds", "Targets"]): # create info for seeds/targets - group_info = deepcopy(info) - group_info = pick_info(group_info, self.indices[group_idx], copy=False) + group_info = pick_info(info, self.indices[group_idx]) with group_info._unlock(): group_info["sfreq"] = 1.0 # 1 component per time point # create Evoked object @@ -768,7 +767,6 @@ def _plot_filters_patterns( ) ) figs[-1].suptitle(group_name) # differentiate seeds from targets - if show: - figs[-1].show() + plt_show(show=show, fig=figs[-1]) return figs From 7709a329172c2e2e9b2f2048134ecacebe436ad3 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Wed, 26 Jun 2024 18:52:12 +0200 Subject: [PATCH 37/38] Update plotting example --- examples/decoding/cohy_decomposition_plotting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/decoding/cohy_decomposition_plotting.py b/examples/decoding/cohy_decomposition_plotting.py index d7768963..75f4f84a 100644 --- a/examples/decoding/cohy_decomposition_plotting.py +++ b/examples/decoding/cohy_decomposition_plotting.py @@ -67,8 +67,9 @@ ######################################################################################## # Generating the filters and patterns # ----------------------------------- -# We will first load some example MEG data which we will generate the spatial filters -# and patterns for, and divide it into epochs. +# We will first load some example MEG data collected during a hand movement task, which +# we will generate the spatial filters and patterns for. We divide the data into +# continuous epochs. # %% From aed36758da7ff646df1914e5b3b7a0055632d41a Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 2 Jul 2024 12:16:41 +0200 Subject: [PATCH 38/38] Add link to ft example --- examples/decoding/cohy_decomposition_plotting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/decoding/cohy_decomposition_plotting.py b/examples/decoding/cohy_decomposition_plotting.py index 75f4f84a..780495b4 100644 --- a/examples/decoding/cohy_decomposition_plotting.py +++ b/examples/decoding/cohy_decomposition_plotting.py @@ -68,8 +68,9 @@ # Generating the filters and patterns # ----------------------------------- # We will first load some example MEG data collected during a hand movement task, which -# we will generate the spatial filters and patterns for. We divide the data into -# continuous epochs. 
+# we will generate the spatial filters and patterns for (see +# `here `_ for more information on +# the data). We divide the data into continuous epochs. # %%