Add a filter function for raw data #41

Merged · 4 commits · Oct 17, 2023
140 changes: 101 additions & 39 deletions braingeneers/analysis/analysis.py
@@ -21,7 +21,6 @@
from braingeneers.utils import s3wrangler
from braingeneers.utils.common_utils import get_basepath


__all__ = [
"DCCResult",
"read_phy_files",
@@ -41,6 +40,7 @@

logger = getLogger("braingeneers.analysis")


@dataclass
class NeuronAttributes:
cluster_id: int
@@ -113,7 +113,6 @@ def load_spike_data(uuid, experiment=None, basepath=None, full_path=None, fs=200
logger.info('prefix: %s', prefix)
path = posixpath.join(basepath, prefix)


if full_path is not None:
experiment = full_path.split('/')[-1].split('.')[0]
logger.info('Using full path, experiment: %s', experiment)
@@ -139,7 +138,6 @@ def load_spike_data(uuid, experiment=None, basepath=None, full_path=None, fs=200
# If path is a local path, check locally
file_list = glob.glob(path + '*.zip')


zip_files = [file for file in file_list if file.endswith('.zip')]

if not zip_files:
@@ -149,8 +147,6 @@ def load_spike_data(uuid, experiment=None, basepath=None, full_path=None, fs=200

path = zip_files[0]



with smart_open.open(path, 'rb') as f0:
f = io.BytesIO(f0.read())
logger.debug('Opening zip file...')
@@ -221,14 +217,13 @@ def load_spike_data(uuid, experiment=None, basepath=None, full_path=None, fs=200

logger.debug('Creating spike data...')

metadata = {"experiment":experiment}
metadata = {"experiment": experiment}
spike_data = SpikeData(cluster_agg["spikeTimes"].to_list(), neuron_attributes=neuron_attributes, metadata=metadata)

logger.debug('Done.')
return spike_data



@deprecated('Prefer load_spike_data()', version='0.1.13')
def read_phy_files(path: str, fs=20000.0):
"""
@@ -319,11 +314,11 @@ def read_phy_files(path: str, fs=20000.0):
config_dict = dict(zip(channels, positions))
neuron_data = {0: neuron_dict}
metadata = {0: config_dict}
spikedata = SpikeData(list(cluster_agg["spikeTimes"]), neuron_data=neuron_data, metadata=metadata, neuron_attributes=neuron_attributes)
spikedata = SpikeData(list(cluster_agg["spikeTimes"]), neuron_data=neuron_data, metadata=metadata,
neuron_attributes=neuron_attributes)
return spikedata



class SpikeData:
"""
Class for handling and manipulating neuronal spike data.
@@ -493,7 +488,6 @@ def from_thresholding(data, fs_Hz=20e3, threshold_sigma=5.0,
return SpikeData.from_raster(raster, 1e3 / fs_Hz,
raw_data=data, raw_time=fs_Hz / 1e3)


def __init__(self, train, *, N=None, length=None,
neuron_attributes=[], neuron_data={}, metadata={},
raw_data=None, raw_time=None):
@@ -640,8 +634,8 @@ def cond(i):
return self.neuron_data[by][i] in units

train = [ts for i, ts in enumerate(self.train) if cond(i)]
neuron_data = {k: [v for i,v in enumerate(vs) if cond(i)]
for k,vs in self.neuron_data.items()}
neuron_data = {k: [v for i, v in enumerate(vs) if cond(i)]
for k, vs in self.neuron_data.items()}

neuron_attributes = []
if len(self.neuron_attributes) >= len(units):
@@ -706,15 +700,14 @@ def __getitem__(self, key):
else:
return self.subset(key)


def append(self, spikeData, offset=0):
'''Appends a spikeData object to the current object. These must have
the same number of neurons.

:param: spikeData: spikeData object to append to the current object
'''
assert self.N == spikeData.N, 'Number of neurons must be the same'
train = ([np.hstack([tr1, tr2 + self.length + offset]) for tr1, tr2 in zip(self.train,spikeData.train)])
train = ([np.hstack([tr1, tr2 + self.length + offset]) for tr1, tr2 in zip(self.train, spikeData.train)])
raw_data = np.concatenate((self.raw_data, spikeData.raw_data), axis=1)
raw_time = np.concatenate((self.raw_time, spikeData.raw_time))
length = self.length + spikeData.length + offset
@@ -726,8 +719,6 @@ def append(self, spikeData, offset=0):
neuron_data=self.neuron_data,
raw_time=raw_time, raw_data=raw_data)



def sparse_raster(self, bin_size=20):
'''
Bin all spike times and create a sparse array where entry
@@ -846,7 +837,6 @@ def concatenate_spike_data(self, sd):
self.metadata.update(sd.metadata)
self.neuron_attributes += sd.neuron_attributes


def spike_time_tilings(self, delt=20):
"""
Compute the full spike time tiling coefficient matrix.
@@ -862,7 +852,6 @@ def spike_time_tilings(self, delt=20):
)
return ret


def spike_time_tiling(self, i, j, delt=20):
'''
Calculate the spike time tiling coefficient between two units within
@@ -944,7 +933,6 @@ def deviation_from_criticality(self, quantile=0.35, bin_size=40,
# Return the DCC value and significance.
return DCCResult(dcc=dcc, p_size=p_size, p_duration=p_dur)


def latencies(self, times, window_ms=100):
'''
Given a sorted list of times, compute the latencies from that time to
@@ -969,7 +957,7 @@ def latencies(self, times, window_ms=100):
abs_diff_ind = np.argmin(np.abs(train - time))

# Calculate the actual latency
latency = np.array(train)-time
latency = np.array(train) - time
latency = latency[abs_diff_ind]

abs_diff = np.abs(latency)
@@ -1003,6 +991,45 @@ def randomized(self, bin_size_ms=1.0, seed=None):
neuron_attributes=self.neuron_attributes
)

def population_firing_rate(self, bin_size=10, w=5, average=False):
"""
Population firing rate of all units in the SpikeData object.
"""
bins, pop_rate = population_firing_rate(self.train, self.length, bin_size, w, average)
return bins, pop_rate


def population_firing_rate(trains, rec_length=None, bin_size=10, w=5, average=False):
"""
Calculate population firing rate for given spike trains.
:param trains: a list of spike trains; a single spike train is also accepted
:param rec_length: length of the recording.
If None, the maximum spike time is used.
:param bin_size: binning width
:param w: kernel width for smoothing
:param average: If True, the result is averaged over the number of units.
Otherwise, the summed population rate is returned as is.
:return: An array of the bin edges and an array of the population firing rate
for the given units' spiking activity
"""
if isinstance(trains, (list, np.ndarray)) \
and not isinstance(trains[0], (list, np.ndarray)):
N = 1
else:
N = len(trains)

trains = np.hstack(trains)
if rec_length is None:
rec_length = np.max(trains)

bin_num = int(rec_length // bin_size) + 1
bins = np.linspace(0, rec_length, bin_num)
fr = np.histogram(trains, bins)[0] / bin_size
fr_pop = np.convolve(fr, np.ones(w), 'same') / w
if average:
fr_pop /= N
return bins, fr_pop
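
As a quick illustration of the new helper, here is a minimal sketch; the spike times and recording length are invented for the example:

import numpy as np

# Hypothetical spike trains (in ms) for three units.
trains = [np.array([5.0, 12.0, 40.0, 88.0]),
          np.array([7.0, 30.0, 71.0]),
          np.array([3.0, 33.0, 61.0])]

# Smoothed population rate in 10 ms bins over a 100 ms recording,
# averaged over the number of units.
bins, rate = population_firing_rate(trains, rec_length=100, bin_size=10, w=5, average=True)

Note that bins holds the histogram bin edges, so it is one element longer than rate.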


def spike_time_tiling(tA, tB, delt=20, length=None):
"""
@@ -1057,7 +1084,6 @@ def best_effort_sample(counts, M, rng=np.random):
return ret



def randomize_raster(raster, seed=None):
"""
Randomize a raster by taking out all the spikes in each time bin and
Expand Down Expand Up @@ -1101,19 +1127,17 @@ def filter(raw_data, fs_Hz=20000, filter_order=3, filter_lo_Hz=300,
:return: filtered data
'''


time_step_size = int(time_step_size_s * fs_Hz)
data = np.zeros_like(raw_data)


# Get filter params
b, a = signal.butter(fs=fs_Hz, btype='bandpass',
N=filter_order, Wn=[filter_lo_Hz, filter_hi_Hz])

if zi is None:
# Filter initial state
zi = signal.lfilter_zi(b, a)
zi = np.vstack([zi*np.mean(raw_data[ch,:5])
zi = np.vstack([zi * np.mean(raw_data[ch, :5])
for ch in range(raw_data.shape[0])])

# Step through the data in chunks and filter it
@@ -1125,9 +1149,9 @@ def filter(raw_data, fs_Hz=20000, filter_order=3, filter_lo_Hz=300,
for t_start in range(0, raw_data.shape[1], time_step_size):
t_end = min(t_start + time_step_size, raw_data.shape[1])

data[ch_start:ch_end, t_start:t_end], zi[ch_start:ch_end,:] = signal.lfilter(
data[ch_start:ch_end, t_start:t_end], zi[ch_start:ch_end, :] = signal.lfilter(
b, a, raw_data[ch_start:ch_end, t_start:t_end],
axis=1, zi=zi[ch_start:ch_end,:])
axis=1, zi=zi[ch_start:ch_end, :])

return data if not return_zi else (data, zi)
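
A hedged usage sketch of the chunked filter: the array shape and sampling rate are invented, parameters not visible in this hunk keep their defaults, and zi/return_zi are assumed to be keyword arguments as the body suggests (filter here refers to the module-level function above, which shadows the builtin):

import numpy as np

# Hypothetical raw recording: 4 channels, 1 s at 20 kHz.
raw = np.random.randn(4, 20000)

# Band-pass filter in chunks; the result has the same shape as raw.
filtered = filter(raw, fs_Hz=20000, filter_order=3, filter_lo_Hz=300)

# Also return the final filter state, e.g. to continue on the next chunk.
filtered, zi = filter(raw, fs_Hz=20000, return_zi=True)
next_chunk = filter(np.random.randn(4, 20000), fs_Hz=20000, zi=zi)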

@@ -1144,8 +1168,8 @@ def _resampled_isi(spikes, times, sigma_ms):
elif len(spikes) == 1:
return np.ones_like(times) / spikes[0]
else:
x = 0.5*(spikes[:-1] + spikes[1:])
y = 1/np.diff(spikes)
x = 0.5 * (spikes[:-1] + spikes[1:])
y = 1 / np.diff(spikes)
fr = np.interp(times, x, y)
if len(np.atleast_1d(fr)) < 2:
return fr
@@ -1238,7 +1262,7 @@ def fano_factors(raster):
# This is the variance/mean ratio computed in a sparse-friendly
# way. This algorithm is numerically unstable in general, but
# should only be a problem if your bin size is way too big.
return moment/mean - mean
return moment / mean - mean

else:
mean = np.asarray(raster).mean(1)
@@ -1257,7 +1281,7 @@ def _sttc_ta(tA, delt, tmax):
return 0

base = min(delt, tA[0]) + min(delt, tmax - tA[-1])
return base + np.minimum(np.diff(tA), 2*delt).sum()
return base + np.minimum(np.diff(tA), 2 * delt).sum()


def _sttc_na(tA, tB, delt):
@@ -1275,9 +1299,9 @@ def _sttc_na(tA, tB, delt):

# Clip to ensure legal indexing, then check the spike at that
# index and its predecessor to see which is closer.
np.clip(iB, 1, len(tB)-1, out=iB)
np.clip(iB, 1, len(tB) - 1, out=iB)
dt_left = np.abs(tB[iB] - tA)
dt_right = np.abs(tB[iB-1] - tA)
dt_right = np.abs(tB[iB - 1] - tA)

# Return how many of those spikes are actually within delt.
return (np.minimum(dt_left, dt_right) <= delt).sum()
@@ -1295,8 +1319,8 @@ def pearson(spikes):

Exy = (spikes @ spikes.T) / spikes.shape[1]
Ex = spikes.mean(axis=1)
Ex2 = (spikes**2).mean(axis=1)
σx = np.sqrt(Ex2 - Ex**2)
Ex2 = (spikes ** 2).mean(axis=1)
σx = np.sqrt(Ex2 - Ex ** 2)

# Some cells won't fire in the whole observation window. To get their
# correlation coefficients to zero, give them infinite σ.
@@ -1319,7 +1343,7 @@ def cumulative_moving_average(hist):
cma = 0
cma_list = []
for i in range(len(h)):
cma = (cma * i + h[i]) / (i+1)
cma = (cma * i + h[i]) / (i + 1)
cma_list.append(cma)
ret.append(cma_list)
return ret
@@ -1336,17 +1360,55 @@ def burst_detection(spike_times, burst_threshold, spike_num_thr=3):
'''
spike_num_burst = 1
spike_num_list = []
for i in range(len(spike_times)-1):
if spike_times[i+1] - spike_times[i] <= burst_threshold:
for i in range(len(spike_times) - 1):
if spike_times[i + 1] - spike_times[i] <= burst_threshold:
spike_num_burst += 1
else:
if spike_num_burst >= spike_num_thr:
spike_num_list.append([i-spike_num_burst+1, spike_num_burst])
spike_num_list.append([i - spike_num_burst + 1, spike_num_burst])
spike_num_burst = 1
else:
spike_num_burst = 1
burst_set = []
for loc in spike_num_list:
for i in range(loc[1]):
burst_set.append(spike_times[loc[0]+i])
burst_set.append(spike_times[loc[0] + i])
return spike_num_list, burst_set
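
A small worked example of the burst detector; the spike times are invented, and the times and threshold are in the same (arbitrary) units:

import numpy as np

# Hypothetical spike times with two dense groups and one isolated spike.
spike_times = np.array([0, 4, 8, 11, 50, 53, 57, 61, 200])

spike_num_list, burst_set = burst_detection(spike_times, burst_threshold=10, spike_num_thr=3)
# spike_num_list == [[0, 4], [4, 4]]: two bursts of four spikes starting at indices 0 and 4.
# burst_set == [0, 4, 8, 11, 50, 53, 57, 61]: the spike times belonging to those bursts.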


def butter_filter(data, lowcut=None, highcut=None, fs=20000.0, order=5):
"""
A digital Butterworth filter. The filter type is determined by which cutoffs are given.
Inputs:
data: array_like data to be filtered
lowcut: low cutoff frequency. If None or 0, highcut must be a number.
Filter is lowpass.
highcut: high cutoff frequency. If None, lowcut must be a non-zero number.
Filter is highpass.
If lowcut and highcut are both given, the filter is bandpass.
In this case, lowcut must be smaller than highcut.
fs: sample rate
order: order of the filter
Return:
The filtered output with the same shape as data
"""

assert (lowcut not in [None, 0]) or (highcut is not None), \
"Need at least a low cutoff (lowcut) or high cutoff (highcut) frequency!"
if (lowcut is not None) and (highcut is not None):
assert lowcut < highcut, "lowcut must be smaller than highcut"

if lowcut is None or lowcut == 0:
filter_type = 'lowpass'
Wn = highcut / fs * 2
elif highcut is None:
filter_type = 'highpass'
Wn = lowcut / fs * 2
else:
filter_type = "bandpass"
band = [lowcut, highcut]
Wn = [e / fs * 2 for e in band]

filter_coeff = signal.iirfilter(order, Wn, analog=False, btype=filter_type, output='sos')
filtered_traces = signal.sosfiltfilt(filter_coeff, data)
return filtered_traces
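
For reference, a brief sketch of the three filter modes; the data and cutoff values are arbitrary:

import numpy as np

# Hypothetical single-channel trace: 1 s of noise at 20 kHz.
data = np.random.randn(20000)

low = butter_filter(data, highcut=100)                # lowpass below 100 Hz
high = butter_filter(data, lowcut=300)                # highpass above 300 Hz
band = butter_filter(data, lowcut=300, highcut=3000)  # bandpass 300-3000 Hz

In each case Wn is the cutoff normalised to the Nyquist frequency (fs / 2), which is what the e / fs * 2 expressions compute.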