reblackened
robintibor committed May 6, 2019
1 parent 22c9af2 commit c930a29
Showing 3 changed files with 69 additions and 42 deletions.
7 changes: 5 additions & 2 deletions braindecode/datasets/lazy_dataset.py
@@ -5,10 +5,13 @@ class LazyDataset(ABC):
     """ Class implementing an abstract lazy data set. Custom lazy data sets
     have to override file_paths, X and y as well as the load_lazy function to
     load trials or crops. """
+
     def __init__(self):
         self.file_paths = "Not implemented: a list of all file paths"
-        self.X = ("Not implemented: a list of empty ndarrays with number of "
-                  "samples as second dimension")
+        self.X = (
+            "Not implemented: a list of empty ndarrays with number of "
+            "samples as second dimension"
+        )
         self.y = "Not implemented: a list of all targets"
 
     @abstractmethod
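For orientation, a concrete subclass fills in those three attributes and implements the abstract loader. The sketch below is illustrative and not part of this commit; the constructor arguments and the exact load_lazy signature are assumptions inferred from the docstring.

import numpy as np

from braindecode.datasets.lazy_dataset import LazyDataset


class NpyLazyDataset(LazyDataset):
    """Hypothetical lazy data set: metadata in memory, signals read on demand."""

    def __init__(self, file_paths, n_samples_per_file, n_channels, targets):
        self.file_paths = file_paths
        # empty ndarrays whose second dimension is the number of samples,
        # matching the contract described in the docstring above
        self.X = [np.empty((n_channels, n)) for n in n_samples_per_file]
        self.y = targets

    def load_lazy(self, i_trial, start, stop):
        # assumed signature: read one trial from disk, return the crop
        data = np.load(self.file_paths[i_trial])
        return data[:, start:stop]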
67 changes: 46 additions & 21 deletions braindecode/datautil/lazy_iterators.py
@@ -4,8 +4,10 @@
 import torch as th
 import numpy as np
 
-from braindecode.datautil.iterators import _compute_start_stop_block_inds, \
-    get_balanced_batches
+from braindecode.datautil.iterators import (
+    _compute_start_stop_block_inds,
+    get_balanced_batches,
+)
 
 
 def custom_collate(batch, rng_state=None):
@@ -17,7 +19,7 @@ def custom_collate(batch, rng_state=None):
     want to decrease using lazy loading
     """
     elem_type = type(batch[0])
-    if elem_type.__module__ == 'numpy':
+    if elem_type.__module__ == "numpy":
         if rng_state is not None:
             th.random.set_rng_state(rng_state)
         return np.stack([b for b in batch], 0)
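A minimal sketch of the numpy branch above, assuming only that the batch elements are ndarrays (not code from this repository):

import numpy as np
import torch as th

from braindecode.datautil.lazy_iterators import custom_collate

# three crops, as a lazy data set would deliver them
batch = [np.zeros((2, 4), dtype=np.float32) for _ in range(3)]

state = th.random.get_rng_state()
# restores the given torch RNG state, then stacks along a new batch axis
stacked = custom_collate(batch, rng_state=state)
print(stacked.shape)  # (3, 2, 4)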
@@ -54,10 +56,18 @@ class LazyCropsFromTrialsIterator(object):
         Checking validity of predictions and trial lengths. Disable to decrease
         runtime.
     """
-    def __init__(self, input_time_length, n_preds_per_input, batch_size,
-                 seed=328774, num_workers=0, collate_fn=custom_collate,
-                 check_preds_smaller_trial_len=True,
-                 reset_rng_after_each_batch=False):
+
+    def __init__(
+        self,
+        input_time_length,
+        n_preds_per_input,
+        batch_size,
+        seed=328774,
+        num_workers=0,
+        collate_fn=custom_collate,
+        check_preds_smaller_trial_len=True,
+        reset_rng_after_each_batch=False,
+    ):
         self.batch_size = batch_size
         self.seed = seed
         self.rng = RandomState(self.seed)
@@ -82,11 +92,14 @@ def get_batches(self, dataset, shuffle):
             collate_fn = partial(self.collate_fn, rng_state=random_state)
         else:
             collate_fn = partial(self.collate_fn, rng_state=None)
-        batch_indeces = self._get_batch_indeces(dataset=dataset,
-                                                shuffle=shuffle)
-        data_loader = DataLoader(dataset=dataset, batch_sampler=batch_indeces,
-                                 num_workers=self.num_workers,
-                                 pin_memory=False, collate_fn=collate_fn)
+        batch_indeces = self._get_batch_indeces(dataset=dataset, shuffle=shuffle)
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_indeces,
+            num_workers=self.num_workers,
+            pin_memory=False,
+            collate_fn=collate_fn,
+        )
         return data_loader
 
     def _get_batch_indeces(self, dataset, shuffle):
@@ -101,23 +114,35 @@ def _get_batch_indeces(self, dataset, shuffle):
         for i_trial, input_len in enumerate(input_lens):
             assert input_len >= self.input_time_length, (
                 "Input length {:d} of trial {:d} is smaller than the "
-                "input time length {:d}".format(input_len, i_trial,
-                                                self.input_time_length))
+                "input time length {:d}".format(
+                    input_len, i_trial, self.input_time_length
+                )
+            )
 
         start_stop_blocks_per_trial = _compute_start_stop_block_inds(
-            i_trial_starts, i_trial_stops, self.input_time_length,
+            i_trial_starts,
+            i_trial_stops,
+            self.input_time_length,
             self.n_preds_per_input,
-            check_preds_smaller_trial_len=self.check_preds_smaller_trial_len)
+            check_preds_smaller_trial_len=self.check_preds_smaller_trial_len,
+        )
         for i_trial, trial_blocks in enumerate(start_stop_blocks_per_trial):
             assert trial_blocks[0][0] == 0
             assert trial_blocks[-1][1] == i_trial_stops[i_trial]
 
-        i_trial_start_stop_block = np.array([
-            (i_trial, start, stop) for i_trial, block in
-            enumerate(start_stop_blocks_per_trial) for start, stop in block])
+        i_trial_start_stop_block = np.array(
+            [
+                (i_trial, start, stop)
+                for i_trial, block in enumerate(start_stop_blocks_per_trial)
+                for start, stop in block
+            ]
+        )
 
         batches = get_balanced_batches(
-            n_trials=len(i_trial_start_stop_block), rng=self.rng,
-            shuffle=shuffle, batch_size=self.batch_size)
+            n_trials=len(i_trial_start_stop_block),
+            rng=self.rng,
+            shuffle=shuffle,
+            batch_size=self.batch_size,
+        )
 
         return [i_trial_start_stop_block[batch_ind] for batch_ind in batches]
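Putting the reblackened pieces together, typical use of the iterator might look as follows. This is a sketch: dataset stands for any concrete LazyDataset, and the window sizes are placeholder values. Since get_batches returns a torch DataLoader built from the (i_trial, start, stop) crop indices computed above, iterating it yields collated (inputs, targets) pairs.

from braindecode.datautil.lazy_iterators import LazyCropsFromTrialsIterator

iterator = LazyCropsFromTrialsIterator(
    input_time_length=600,  # placeholder: samples per crop
    n_preds_per_input=100,  # placeholder: predictions per crop
    batch_size=32,
    num_workers=2,  # loading happens in DataLoader worker processes
)
for inputs, targets in iterator.get_batches(dataset, shuffle=True):
    ...  # train on one lazily loaded batch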
37 changes: 18 additions & 19 deletions braindecode/experiments/experiment.py
@@ -232,14 +232,13 @@ def run(self):
             log.info("Run until second stop...")
             loss_to_reach = float(self.epochs_df["train_loss"].iloc[-1])
             self.run_until_second_stop()
-            if ((float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach)
-                    and self.reset_after_second_run):
+            if (
+                float(self.epochs_df["valid_loss"].iloc[-1]) > loss_to_reach
+            ) and self.reset_after_second_run:
                 # if no valid loss was found below the best train loss on 1st
                 # run, reset model to the epoch with lowest valid_misclass
                 log.info(
-                    "Resetting to best epoch {:d}".format(
-                        self.rememberer.best_epoch
-                    )
+                    "Resetting to best epoch {:d}".format(self.rememberer.best_epoch)
                 )
                 self.rememberer.reset_to_best_model(
                     self.epochs_df, self.model, self.optimizer
@@ -426,8 +425,7 @@ def monitor_epoch(self, datasets):
             # iterating through traditional iterators is cheap, since
             # nothing is loaded, recreate generator afterwards
             n_batches = sum(1 for i in batch_generator)
-            batch_generator = self.iterator.get_batches(dataset,
-                                                        shuffle=False)
+            batch_generator = self.iterator.get_batches(dataset, shuffle=False)
             all_preds, all_targets = None, None
             all_losses, all_batch_sizes = [], []
             for inputs, targets in batch_generator:
@@ -439,12 +437,13 @@ def monitor_epoch(self, datasets):
                     # first batch size is largest
                     max_size, n_classes, n_preds_per_input = preds.shape
                     # pre-allocate memory for all predictions and targets
-                    all_preds = np.nan * np.ones(
+                    all_preds = np.nan * np.ones(
                         (n_batches * max_size, n_classes, n_preds_per_input),
-                        dtype=np.float32)
-                    all_preds[:len(preds)] = preds
-                    all_targets = np.nan * np.ones((n_batches * max_size))
-                    all_targets[:len(targets)] = targets
+                        dtype=np.float32,
+                    )
+                    all_preds[: len(preds)] = preds
+                    all_targets = np.nan * np.ones((n_batches * max_size))
+                    all_targets[: len(targets)] = targets
                 else:
                     start_i = sum(all_batch_sizes[:-1])
                     stop_i = sum(all_batch_sizes)
@@ -456,15 +455,15 @@ def monitor_epoch(self, datasets):
             all_batch_sizes = sum(all_batch_sizes)
             # remove nan rows in case of unequal batch sizes
             if unequal_batches:
-                assert np.sum(np.isnan(all_preds[:all_batch_sizes - 1])) == 0
+                assert np.sum(np.isnan(all_preds[: all_batch_sizes - 1])) == 0
                 assert np.sum(np.isnan(all_preds[all_batch_sizes:])) > 0
                 range_to_delete = range(all_batch_sizes, len(all_preds))
                 all_preds = np.delete(all_preds, range_to_delete, axis=0)
                 all_targets = np.delete(all_targets, range_to_delete, axis=0)
-            assert np.sum(np.isnan(all_preds)) == 0, (
-                "There are still nans in predictions")
-            assert np.sum(np.isnan(all_targets)) == 0, (
-                "There are still nans in targets")
+            assert (
+                np.sum(np.isnan(all_preds)) == 0
+            ), "There are still nans in predictions"
+            assert np.sum(np.isnan(all_targets)) == 0, "There are still nans in targets"
             # add empty dimension
             # monitors expect n_batches x ...
             all_preds = all_preds[np.newaxis, :]
@@ -488,8 +487,8 @@ def monitor_epoch(self, datasets):
             row_dict.update(result_dicts_per_monitor[m])
         self.epochs_df = self.epochs_df.append(row_dict, ignore_index=True)
         assert set(self.epochs_df.columns) == set(row_dict.keys()), (
-            "Columns of dataframe: {:s}\n and keys of dict {:s} not same")\
-            .format(str(set(self.epochs_df.columns)), str(set(row_dict.keys())))
+            "Columns of dataframe: {:s}\n and keys of dict {:s} not same"
+        ).format(str(set(self.epochs_df.columns)), str(set(row_dict.keys())))
         self.epochs_df = self.epochs_df[list(row_dict.keys())]
 
     def log_epoch(self):
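The monitor_epoch hunks above all revolve around one pattern: pre-allocate NaN-filled arrays sized for n_batches copies of the largest batch, fill in what actually arrives, then delete the untouched tail when batch sizes were unequal. A self-contained restatement of that pattern with illustrative values:

import numpy as np

n_batches, max_size, n_classes = 3, 8, 2
all_preds = np.nan * np.ones((n_batches * max_size, n_classes), dtype=np.float32)

n_filled = 0
for preds in [np.ones((8, 2)), np.ones((8, 2)), np.ones((5, 2))]:  # last batch smaller
    all_preds[n_filled : n_filled + len(preds)] = preds
    n_filled += len(preds)

# remove nan rows left over from the smaller final batch
all_preds = np.delete(all_preds, range(n_filled, len(all_preds)), axis=0)
assert np.sum(np.isnan(all_preds)) == 0, "There are still nans in predictions"
assert all_preds.shape == (21, 2)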
