Skip to content

Commit

Permalink
now using convolution for average pooling in deepnet, for possiblitiy…
Browse files Browse the repository at this point in the history
… to use dilation
  • Loading branch information
robintibor committed Sep 18, 2017
1 parent 4ed5d53 commit 39d09cb
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
4 changes: 2 additions & 2 deletions braindecode/models/deep4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn
from torch.nn import init
from torch.nn.functional import elu
from braindecode.torch_ext.modules import Expression
from braindecode.torch_ext.modules import Expression, AvgPool2dWithConv
from braindecode.torch_ext.functions import identity
from braindecode.torch_ext.util import np_to_var

Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, in_chans,
del self.self

def create_network(self):
pool_class_dict = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)
pool_class_dict = dict(max=nn.MaxPool2d, mean=AvgPool2dWithConv)
first_pool_class = pool_class_dict[self.first_pool_mode]
later_pool_class = pool_class_dict[self.later_pool_mode]
model = nn.Sequential()
Expand Down
52 changes: 51 additions & 1 deletion braindecode/torch_ext/modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import numpy as np
import torch
import torch.nn.functional as F

from braindecode.torch_ext.util import np_to_var


class Expression(torch.nn.Module):
Expand Down Expand Up @@ -27,4 +31,50 @@ def __repr__(self):
else:
expression_str = self.expression_fn.__name__
return (self.__class__.__name__ + '(' +
'expression=' + str(expression_str) + ')')
'expression=' + str(expression_str) + ')')


class AvgPool2dWithConv(torch.nn.Module):
"""
Compute average pooling using a convolution, to have the dilation parameter.
Parameters
----------
kernel_size: (int,int)
Size of the pooling region.
stride: (int,int)
Stride of the pooling operation.
dilation: int or (int,int)
Dilation applied to the pooling filter.
"""
def __init__(self, kernel_size, stride, dilation=1):
super(AvgPool2dWithConv, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.weights = None

def forward(self, x):
# Create weights for the convolution on demand:
# size or type of x changed...
in_channels = x.size()[1]
weight_shape = (in_channels, 1,
self.kernel_size[0], self.kernel_size[1])
if self.weights is None or (
(tuple(self.weights.size()) != tuple(weight_shape)) or (
self.weights.is_cuda != x.is_cuda
) or (
self.weights.data.type() != x.data.type()
)):
n_pool = np.prod(self.kernel_size)
weights = np_to_var(
np.ones(weight_shape, dtype=np.float32) / float(n_pool))
weights = weights.type_as(x)
if x.is_cuda:
weights = weights.cuda()
self.weights = weights

pooled = F.conv2d(x, self.weights, bias=None, stride=self.stride,
dilation=self.dilation,
groups=in_channels,)
return pooled
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Versions should comply with PEP440. For a discussion on single-sourcing
# the version across setup.py and the project code, see
# http://packaging.python.org/en/latest/tutorial.html#version
version='0.1.7', # TODO: read from __init__.py?
version='0.1.8', # TODO: read from __init__.py?

description='A deep learning toolbox to decode raw time-domain EEG.',
long_description=long_description,
Expand Down

0 comments on commit 39d09cb

Please sign in to comment.