Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Cedric Kulbach committed Feb 13, 2023
1 parent ed6d452 commit a1fd132
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions deep_river/anomaly/rolling_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@


class _TestLSTMAutoencoder(nn.Module):
def __init__(
self, n_features, hidden_size=10, n_layers=1, batch_first=False
):
def __init__(self, n_features, hidden_size=30, n_layers=1,
batch_first=False):
super().__init__()
self.n_features = n_features
self.hidden_size = hidden_size
Expand All @@ -39,9 +38,7 @@ def forward(self, x):
x_flipped = torch.flip(x[1:], dims=[self.time_axis])
input = torch.cat((h, x_flipped), dim=self.time_axis)
x_hat, _ = self.decoder(input)
x_hat = torch.flip(x_hat, dims=[self.time_axis])

return x_hat
return torch.flip(x_hat, dims=[self.time_axis])


class RollingAutoencoder(RollingDeepEstimator, anomaly.base.AnomalyDetector):
Expand Down

0 comments on commit a1fd132

Please sign in to comment.