
Commit

Merge pull request #83 from online-ml/35-add-network-factories-with-basic-pre-implemented-network-architectures

35 add network factories with basic pre implemented network architectures
kulbachcedric committed Feb 15, 2023
2 parents 8f44b6d + fadbbc2 commit 36b9b47
Showing 9 changed files with 495 additions and 18 deletions.
4 changes: 4 additions & 0 deletions deep_river/anomaly/__init__.py
@@ -7,6 +7,10 @@
    AnomalyStandardScaler,
)

"""
This module contains the anomaly detection algorithms for the
deep_river package.
"""
__all__ = [
"Autoencoder",
"RollingAutoencoder",
15 changes: 3 additions & 12 deletions deep_river/base.py
@@ -1,7 +1,7 @@
import abc
import collections
import inspect
from typing import Any, Callable, Deque, Dict, Optional, Type, Union, cast
from typing import Any, Callable, Deque, Optional, Type, Union, cast

import torch
from river import base
@@ -72,7 +72,7 @@ def __init__(

    @abc.abstractmethod
    def learn_one(
        self, x: dict, y: Optional[Union[Any, Dict[Any, Any]]], **kwargs
        self, x: dict, y: Optional[Any], **kwargs
    ) -> "DeepEstimator":
        """
        Performs one step of training with a single example.
@@ -170,16 +170,7 @@ def clone(self, new_params: dict = {}, include_attributes=False):
"""
new_params = new_params or {}
new_params.update(self.kwargs)
new_params.update(
{
"seed": self.seed,
"device": self.device,
"lr": self.lr,
"loss_fn": self.loss_fn,
"optimizer_fn": self.optimizer_fn,
"module": self.module_cls,
}
)
new_params.update(self._get_params())

clone = self.__class__(**new_params)
if include_attributes:
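For context, the replaced block above means clone() now pulls every constructor argument from river's _get_params() instead of a hand-maintained dict. A minimal usage sketch (the estimator choice and learning-rate values are illustrative, not part of this diff):

# Illustrative only: cloning keeps the configuration, optionally overriding params.
from deep_river.classification import LogisticRegression

model = LogisticRegression(lr=1e-3, optimizer_fn="sgd")
# Same configuration, overridden learning rate, freshly initialised weights.
fresh = model.clone(new_params={"lr": 1e-2})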
9 changes: 9 additions & 0 deletions deep_river/classification/__init__.py
@@ -1,7 +1,16 @@
from deep_river.classification.classifier import Classifier
from deep_river.classification.rolling_classifier import RollingClassifier
from deep_river.classification.zoo import (
    LogisticRegression,
    MultiLayerPerceptron,
)

"""
This module contains the classifiers for the deep_river package.
"""
__all__ = [
"Classifier",
"RollingClassifier",
"MultiLayerPerceptron",
"LogisticRegression",
]
4 changes: 2 additions & 2 deletions deep_river/classification/classifier.py
@@ -76,7 +76,7 @@ class Classifier(DeepEstimator, base.MiniBatchClassifier):
        Device to run the wrapped model on. Can be "cpu" or "cuda".
    seed
        Random seed to be used for training the wrapped model.
    **net_params
    **kwargs
        Parameters to be passed to the `build_fn` function aside from
        `n_features`.
@@ -136,9 +136,9 @@ def __init__(
        **kwargs,
    ):
        super().__init__(
            module=module,
            loss_fn=loss_fn,
            optimizer_fn=optimizer_fn,
            module=module,
            device=device,
            lr=lr,
            seed=seed,
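As the docstring above notes, every extra keyword argument is forwarded to the wrapped module's constructor besides n_features; the zoo classes added below rely on this to pass their architecture sizes through. A small sketch under that assumption (the module name and n_hidden are made up for the example):

from torch import nn

from deep_river.classification import Classifier


class HiddenNet(nn.Module):  # hypothetical module, not part of this commit
    def __init__(self, n_features, n_hidden):
        super().__init__()
        self.hidden = nn.Linear(n_features, n_hidden)
        self.out = nn.Linear(n_hidden, 1)

    def forward(self, X, **kwargs):
        return self.out(self.hidden(X))  # raw logit


# n_hidden is not a Classifier parameter; it is forwarded to HiddenNet.__init__.
clf = Classifier(
    module=HiddenNet,
    loss_fn="binary_cross_entropy_with_logits",
    optimizer_fn="sgd",
    output_is_logit=True,
    n_hidden=16,
)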
250 changes: 250 additions & 0 deletions deep_river/classification/zoo.py
@@ -0,0 +1,250 @@
from typing import Callable, Union

from torch import nn

from deep_river.classification import Classifier


class LogisticRegression(Classifier):
"""
This class implements a logistic regression model in PyTorch.
Parameters
----------
loss_fn
Loss function to be used for training the wrapped model. Can be a
loss function provided by `torch.nn.functional` or one of the
following: 'mse', 'l1', 'cross_entropy',
'binary_cross_entropy_with_logits', 'binary_crossentropy',
'smooth_l1', 'kl_div'.
optimizer_fn
Optimizer to be used for training the wrapped model.
Can be an optimizer class provided by `torch.optim` or one of the
following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".
lr
Learning rate of the optimizer.
output_is_logit
Whether the module produces logits as output. If true, either
softmax or sigmoid is applied to the outputs when predicting.
is_class_incremental
Whether the classifier should adapt to the appearance of
previously unobserved classes by adding an unit to the output
layer of the network. This works only if the last trainable
layer is an nn.Linear layer. Note also, that output activation
functions can not be adapted, meaning that a binary classifier
with a sigmoid output can not be altered to perform multi-class
predictions.
device
Device to run the wrapped model on. Can be "cpu" or "cuda".
seed
Random seed to be used for training the wrapped model.
**kwargs
Parameters to be passed to the `build_fn` function aside from
`n_features`.
Examples
--------
>>> from deep_river.classification import LogisticRegression
>>> from river import metrics, preprocessing, compose, datasets
>>> from torch import nn, manual_seed
>>> _ = manual_seed(42)
>>> model_pipeline = compose.Pipeline(
... preprocessing.StandardScaler(),
... LogisticRegression()
... )
>>> dataset = datasets.Phishing()
>>> metric = metrics.Accuracy()
>>> for x, y in dataset:
... y_pred = model_pipeline.predict_one(x) # make a prediction
... metric = metric.update(y, y_pred) # update the metric
... model_pipeline = model_pipeline.learn_one(x, y) # update the model
>>> print(f"Accuracy: {metric.get():.2f}")
Accuracy: 0.44
"""

    class LRModule(nn.Module):
        def __init__(self, n_features):
            super().__init__()
            self.dense0 = nn.Linear(n_features, 1)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, X, **kwargs):
            X = self.dense0(X)
            return self.softmax(X)

    def __init__(
        self,
        loss_fn: Union[str, Callable] = "binary_cross_entropy_with_logits",
        optimizer_fn: Union[str, Callable] = "sgd",
        lr: float = 1e-3,
        output_is_logit: bool = True,
        is_class_incremental: bool = False,
        device: str = "cpu",
        seed: int = 42,
        **kwargs,
    ):
        super().__init__(
            module=LogisticRegression.LRModule,
            loss_fn=loss_fn,
            output_is_logit=output_is_logit,
            is_class_incremental=is_class_incremental,
            optimizer_fn=optimizer_fn,
            device=device,
            lr=lr,
            seed=seed,
            **kwargs,
        )

    @classmethod
    def _unit_test_params(cls):
        """
        Returns a dictionary of parameters to be used for unit testing the
        respective class.
        Yields
        -------
        dict
            Dictionary of parameters to be used for unit testing the
            respective class.
        """

        yield {
            "loss_fn": "binary_cross_entropy_with_logits",
            "optimizer_fn": "sgd",
        }


class MultiLayerPerceptron(Classifier):
"""
This class implements a logistic regression model in PyTorch.
Parameters
----------
n_width
Number of units in each hidden layer.
n_layers
Number of hidden layers.
loss_fn
Loss function to be used for training the wrapped model. Can be a
loss function provided by `torch.nn.functional` or one of the
following: 'mse', 'l1', 'cross_entropy',
'binary_cross_entropy_with_logits', 'binary_crossentropy',
'smooth_l1', 'kl_div'.
optimizer_fn
Optimizer to be used for training the wrapped model.
Can be an optimizer class provided by `torch.optim` or one of the
following: "adam", "adam_w", "sgd", "rmsprop", "lbfgs".
lr
Learning rate of the optimizer.
output_is_logit
Whether the module produces logits as output. If true, either
softmax or sigmoid is applied to the outputs when predicting.
is_class_incremental
Whether the classifier should adapt to the appearance of
previously unobserved classes by adding an unit to the output
layer of the network. This works only if the last trainable
layer is an nn.Linear layer. Note also, that output activation
functions can not be adapted, meaning that a binary classifier
with a sigmoid output can not be altered to perform multi-class
predictions.
device
Device to run the wrapped model on. Can be "cpu" or "cuda".
seed
Random seed to be used for training the wrapped model.
**kwargs
Parameters to be passed to the `build_fn` function aside from
`n_features`.
Examples
--------
>>> from deep_river.classification import MultiLayerPerceptron
>>> from river import metrics, preprocessing, compose, datasets
>>> from torch import nn, manual_seed
>>> _ = manual_seed(42)
>>> model_pipeline = compose.Pipeline(
... preprocessing.StandardScaler(),
... MultiLayerPerceptron(n_width=5,n_layers=5)
... )
>>> dataset = datasets.Phishing()
>>> metric = metrics.Accuracy()
>>> for x, y in dataset:
... y_pred = model_pipeline.predict_one(x) # make a prediction
... metric = metric.update(y, y_pred) # update the metric
... model_pipeline = model_pipeline.learn_one(x, y) # update the model
>>> print(f"Accuracy: {metric.get():.2f}")
Accuracy: 0.44
"""

    class MLPModule(nn.Module):
        def __init__(self, n_width, n_layers, n_features):
            super().__init__()
            self.dense0 = nn.Linear(n_features, n_width)
            self.block = [nn.Linear(n_width, n_width) for _ in range(n_layers)]
            self.denselast = nn.Linear(n_width, 1)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, X, **kwargs):
            X = self.dense0(X)
            for layer in self.block:
                X = layer(X)
            X = self.denselast(X)
            return self.softmax(X)

    def __init__(
        self,
        n_width: int = 5,
        n_layers: int = 5,
        loss_fn: Union[str, Callable] = "binary_cross_entropy_with_logits",
        optimizer_fn: Union[str, Callable] = "sgd",
        lr: float = 1e-3,
        output_is_logit: bool = True,
        is_class_incremental: bool = False,
        device: str = "cpu",
        seed: int = 42,
        **kwargs,
    ):
        self.n_width = n_width
        self.n_layers = n_layers
        super().__init__(
            module=MultiLayerPerceptron.MLPModule,
            loss_fn=loss_fn,
            output_is_logit=output_is_logit,
            is_class_incremental=is_class_incremental,
            optimizer_fn=optimizer_fn,
            device=device,
            lr=lr,
            seed=seed,
            n_width=n_width,
            n_layers=n_layers,
            **kwargs,
        )

    @classmethod
    def _unit_test_params(cls):
        """
        Returns a dictionary of parameters to be used for unit testing the
        respective class.
        Yields
        -------
        dict
            Dictionary of parameters to be used for unit testing the
            respective class.
        """

        yield {
            "loss_fn": "binary_cross_entropy_with_logits",
            "optimizer_fn": "sgd",
        }
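To make the wiring of MLPModule above concrete, here is a standalone sketch of the equivalent layer stack applied to a dummy batch; sizes are illustrative, and nn.ModuleList is used here so the stacked layers' parameters would be registered if this were wrapped in a module:

import torch
from torch import nn

n_features, n_width, n_layers = 4, 5, 3   # illustrative sizes
dense0 = nn.Linear(n_features, n_width)
block = nn.ModuleList(nn.Linear(n_width, n_width) for _ in range(n_layers))
denselast = nn.Linear(n_width, 1)

X = torch.randn(2, n_features)            # dummy batch of two examples
X = dense0(X)
for layer in block:
    X = layer(X)
print(denselast(X).shape)                 # torch.Size([2, 1])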
10 changes: 10 additions & 0 deletions deep_river/regression/__init__.py
@@ -1,7 +1,17 @@
from deep_river.regression.regressor import Regressor
from deep_river.regression.rolling_regressor import RollingRegressor

# isort: split
from deep_river.regression.multioutput import MultiTargetRegressor
from deep_river.regression.zoo import LinearRegression, MultiLayerPerceptron

"""
This module contains the regressors for the deep_river package.
"""
__all__ = [
"Regressor",
"RollingRegressor",
"MultiTargetRegressor",
"LinearRegression",
"MultiLayerPerceptron",
]
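The regression zoo module itself is not shown in this excerpt; assuming LinearRegression mirrors the classification factories (default constructor, standard river pipeline usage), a usage sketch might look like the following, with the dataset and metric choices purely illustrative:

from river import compose, datasets, metrics, preprocessing

from deep_river.regression import LinearRegression  # assumed default constructor

model = compose.Pipeline(preprocessing.StandardScaler(), LinearRegression())
metric = metrics.MAE()
for x, y in datasets.TrumpApproval():
    y_pred = model.predict_one(x)   # predict, then learn (progressive validation)
    metric = metric.update(y, y_pred)
    model = model.learn_one(x, y)
print(f"MAE: {metric.get():.2f}")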
