Skip to content

Commit

Permalink
feat(sttc): more efficient calculation
Browse files Browse the repository at this point in the history
Calculating the spike time tiling matrix using `spike_time_tilings()`
used to duplicate the work of calculating the constants TA and TB for
each of the units. It's not as useful as I expected, as it only seems
to be about 10% faster, but I've already done and tested it.
  • Loading branch information
atspaeth committed Sep 18, 2023
1 parent 0a681d9 commit 3af53c5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
16 changes: 13 additions & 3 deletions braingeneers/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,15 +833,21 @@ def concatenate_spike_data(self, sd):


def spike_time_tilings(self, delt=20):
'''
"""
Compute the full spike time tiling coefficient matrix.
'''
"""
T = self.length
ts = [_sttc_ta(ts, delt, T) / T for ts in self.train]

ret = np.diag(np.ones(self.N))
for i in range(self.N):
for j in range(i + 1, self.N):
ret[i, j] = ret[j, i] = self.spike_time_tiling(i, j, delt)
ret[i, j] = ret[j, i] = _spike_time_tiling(
self.train[i], self.train[j], ts[i], ts[j], delt
)
return ret


def spike_time_tiling(self, i, j, delt=20):
'''
Calculate the spike time tiling coefficient between two units within
Expand Down Expand Up @@ -1010,7 +1016,11 @@ def spike_time_tiling(tA, tB, delt=20, length=None):

TA = _sttc_ta(tA, delt, length) / length
TB = _sttc_ta(tB, delt, length) / length
return _spike_time_tiling(tA, tB, TA, TB, delt)


def _spike_time_tiling(tA, tB, TA, TB, delt):
"Internal helper method for the second half of STTC calculation."
PA = _sttc_na(tA, tB, delt) / len(tA)
PB = _sttc_na(tB, tA, delt) / len(tB)

Expand Down
1 change: 1 addition & 0 deletions braingeneers/analysis/analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def test_spike_time_tiling_coefficient(self):
self.assertEqual(sttc[0, 1], sttc[1, 0])
self.assertEqual(sttc[0, 0], 1.0)
self.assertEqual(sttc[1, 1], 1.0)
self.assertEqual(sttc[0, 1], foo.spike_time_tiling(0, 1, 1))

# Default arguments, inferred value of tmax.
tmax = max(foo.train[0].ptp(), foo.train[1].ptp())
Expand Down

0 comments on commit 3af53c5

Please sign in to comment.