Added new model ae1svm & updated author info #592

Merged
merged 5 commits on Jul 2, 2024
add new algorithm ae1svm
zhuox5 committed Jun 25, 2024
commit 0d3689433beca6070d6aed15350df0828940957b
52 changes: 52 additions & 0 deletions examples/ae1svm_example.py
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""Example of using AE1SVM for outlier detection (pytorch)
"""
# Author: Zhuo Xiao

from __future__ import division, print_function

import os
import sys

# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pyod.models.ae1svm import AE1SVM
from pyod.utils.data import generate_data
from pyod.utils.data import evaluate_print


if __name__ == "__main__":
    contamination = 0.1  # percentage of outliers
    n_train = 20000  # number of training points
    n_test = 2000  # number of testing points
    n_features = 300  # number of features

    # Generate sample data
    X_train, X_test, y_train, y_test = \
        generate_data(n_train=n_train,
                      n_test=n_test,
                      n_features=n_features,
                      contamination=contamination,
                      random_state=42)

    # train AE1SVM detector
    clf_name = 'AE1SVM'
    clf = AE1SVM(epochs=10)
    clf.fit(X_train)

    # get the prediction labels and outlier scores of the training data
    y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
    y_train_scores = clf.decision_scores_  # raw outlier scores

    # get the prediction on the test data
    y_test_pred = clf.predict(X_test)  # outlier labels (0 or 1)
    y_test_scores = clf.decision_function(X_test)  # outlier scores

    # evaluate and print the results
    print("\nOn Training Data:")
    evaluate_print(clf_name, y_train, y_train_scores)
    print("\nOn Test Data:")
    evaluate_print(clf_name, y_test, y_test_scores)
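
Reviewer note (not part of the committed file): the constructor also exposes the AE-1SVM-specific knobs alpha, sigma, nu, and kernel_approx_features. A minimal sketch of a more customized run; values here are illustrative, not tuned:

    clf = AE1SVM(hidden_neurons=[128, 64],    # encoder layer sizes
                 alpha=0.5,                   # weight on the reconstruction loss
                 sigma=2.0,                   # random Fourier feature bandwidth
                 kernel_approx_features=500,  # random Fourier feature dimension
                 epochs=20,
                 batch_size=64,
                 contamination=contamination)
    clf.fit(X_train)
    y_test_scores = clf.decision_function(X_test)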
183 changes: 183 additions & 0 deletions pyod/models/ae1svm.py
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
"""Using AE-1SVM with Outlier Detection (PyTorch)
Source: https://arxiv.org/pdf/1804.04888
"""
# Author: Zhuo Xiao

from __future__ import division, print_function

import numpy as np
import torch
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
from torch import nn

from .base import BaseDetector
from ..utils.stat_models import pairwise_distances_no_broadcast
from ..utils.torch_utility import get_activation_by_name


class PyODDataset(torch.utils.data.Dataset):
    """PyOD Dataset class for PyTorch Dataloader"""

    def __init__(self, X, y=None, mean=None, std=None):
        super(PyODDataset, self).__init__()
        self.X = X
        self.mean = mean
        self.std = std

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = self.X[idx, :]
        # standardize with the training statistics when provided
        if self.mean is not None and self.std is not None:
            sample = (sample - self.mean) / self.std
        return torch.from_numpy(sample), idx


class InnerAE1SVM(nn.Module):
    """Internal autoencoder with a random-Fourier-feature SVM head."""

    def __init__(self, n_features, encoding_dim, rff_dim, sigma=1.0,
                 hidden_neurons=(128, 64), dropout_rate=0.2, batch_norm=True,
                 hidden_activation='relu'):
        super(InnerAE1SVM, self).__init__()

        self.encoder = nn.Sequential()
        self.decoder = nn.Sequential()
        self.rff = RandomFourierFeatures(encoding_dim, rff_dim, sigma)
        # linear one-class SVM operating on the random Fourier features
        self.svm_weights = nn.Parameter(torch.randn(rff_dim))
        self.svm_bias = nn.Parameter(torch.randn(1))

        activation = get_activation_by_name(hidden_activation)
        layers_neurons_encoder = [n_features, *hidden_neurons, encoding_dim]

        for idx in range(len(layers_neurons_encoder) - 1):
            self.encoder.add_module(
                f"linear{idx}",
                nn.Linear(layers_neurons_encoder[idx],
                          layers_neurons_encoder[idx + 1]))
            if batch_norm:
                self.encoder.add_module(
                    f"batch_norm{idx}",
                    nn.BatchNorm1d(layers_neurons_encoder[idx + 1]))
            self.encoder.add_module(f"activation{idx}", activation)
            self.encoder.add_module(f"dropout{idx}", nn.Dropout(dropout_rate))

        layers_neurons_decoder = layers_neurons_encoder[::-1]

        for idx in range(len(layers_neurons_decoder) - 1):
            self.decoder.add_module(
                f"linear{idx}",
                nn.Linear(layers_neurons_decoder[idx],
                          layers_neurons_decoder[idx + 1]))
            # no batch norm or dropout after the final reconstruction layer
            if batch_norm and idx < len(layers_neurons_decoder) - 2:
                self.decoder.add_module(
                    f"batch_norm{idx}",
                    nn.BatchNorm1d(layers_neurons_decoder[idx + 1]))
            self.decoder.add_module(f"activation{idx}", activation)
            if idx < len(layers_neurons_decoder) - 2:
                self.decoder.add_module(f"dropout{idx}",
                                        nn.Dropout(dropout_rate))

    def forward(self, x):
        x = self.encoder(x)
        rff_features = self.rff(x)
        x = self.decoder(x)
        return x, rff_features

    def svm_decision_function(self, rff_features):
        return torch.matmul(rff_features, self.svm_weights) + self.svm_bias


class RandomFourierFeatures(nn.Module):
    """Random linear projection followed by a cosine, used as a random
    Fourier feature map for kernel approximation."""

    def __init__(self, input_dim, output_dim, sigma=1.0):
        super(RandomFourierFeatures, self).__init__()
        self.weights = nn.Parameter(torch.randn(input_dim, output_dim) * sigma)
        self.bias = nn.Parameter(torch.randn(output_dim) * 2 * np.pi)

    def forward(self, x):
        x = torch.matmul(x, self.weights) + self.bias
        return torch.cos(x)
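
# The projection-plus-cosine map plays the role of a random Fourier feature
# approximation to a shift-invariant kernel; unlike classical RFF, the weights
# and bias here are trainable parameters rather than fixed random draws.
# Illustrative shape check (not part of the PR):
#   rff = RandomFourierFeatures(input_dim=32, output_dim=1000, sigma=1.0)
#   z = rff(torch.randn(4, 32))  # -> shape (4, 1000), entries in [-1, 1]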


class AE1SVM(BaseDetector):
    """Auto Encoder with One-class SVM for anomaly detection."""

    def __init__(self, hidden_neurons=None, hidden_activation='relu',
                 batch_norm=True, learning_rate=1e-3, epochs=100,
                 batch_size=32, dropout_rate=0.2, weight_decay=1e-5,
                 preprocessing=True, loss_fn=None, contamination=0.1,
                 device=None, alpha=1.0, sigma=1.0, nu=0.1,
                 kernel_approx_features=1000):
        super(AE1SVM, self).__init__(contamination=contamination)

        self.model = None
        self.decision_scores_ = None
        self.std = None
        self.mean = None
        self.hidden_neurons = hidden_neurons or [64, 32]
        self.hidden_activation = hidden_activation
        self.batch_norm = batch_norm
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay
        self.preprocessing = preprocessing
        self.loss_fn = loss_fn or torch.nn.MSELoss()
        self.device = device or torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.alpha = alpha
        self.sigma = sigma
        self.nu = nu
        self.kernel_approx_features = kernel_approx_features

    def fit(self, X, y=None):
        X = check_array(X)
        self._set_n_classes(y)

        n_samples, n_features = X.shape
        if self.preprocessing:
            self.mean, self.std = np.mean(X, axis=0), np.std(X, axis=0)
            train_set = PyODDataset(X=X, mean=self.mean, std=self.std)
        else:
            train_set = PyODDataset(X=X)

        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=self.batch_size, shuffle=True)
        self.model = InnerAE1SVM(
            n_features=n_features, encoding_dim=32,
            rff_dim=self.kernel_approx_features, sigma=self.sigma,
            hidden_neurons=self.hidden_neurons,
            dropout_rate=self.dropout_rate, batch_norm=self.batch_norm,
            hidden_activation=self.hidden_activation)
        self.model = self.model.to(self.device)
        self._train_autoencoder(train_loader)

        # restore the best epoch's weights, then score the training data
        self.model.load_state_dict(self.best_model_dict)
        self.decision_scores_ = self.decision_function(X)
        self._process_decision_scores()
        return self

    def _train_autoencoder(self, train_loader):
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=self.learning_rate,
                                     weight_decay=self.weight_decay)
        self.best_loss = float('inf')
        self.best_model_dict = None

        for epoch in range(self.epochs):
            overall_loss = []
            for data, data_idx in train_loader:
                data = data.to(self.device).float()
                reconstructions, rff_features = self.model(data)
                recon_loss = self.loss_fn(data, reconstructions)
                svm_scores = self.model.svm_decision_function(rff_features)
                # one-class hinge loss on the SVM scores
                svm_loss = torch.mean(torch.clamp(1 - svm_scores, min=0))

                loss = self.alpha * recon_loss + svm_loss
                self.model.zero_grad()
                loss.backward()
                optimizer.step()
                overall_loss.append(loss.item())

            epoch_loss = np.mean(overall_loss)
            print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {epoch_loss}')

            # keep the weights from the best-performing epoch
            if epoch_loss < self.best_loss:
                self.best_loss = epoch_loss
                self.best_model_dict = self.model.state_dict()
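
    # Reading directly off the loop above, each batch minimizes
    #     loss = alpha * MSE(x, x_hat)
    #            + mean(max(0, 1 - (rff_features @ svm_weights + svm_bias)))
    # i.e. a weighted reconstruction term plus a one-class hinge penalty
    # on the SVM scores.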

    def decision_function(self, X):
        check_is_fitted(self, ['model', 'best_model_dict'])
        X = check_array(X)
        dataset = PyODDataset(X=X, mean=self.mean, std=self.std) \
            if self.preprocessing else PyODDataset(X=X)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False)
        self.model.eval()

        # outlier score = distance between input and its reconstruction
        outlier_scores = np.zeros([X.shape[0], ])
        with torch.no_grad():
            for data, data_idx in dataloader:
                data = data.to(self.device).float()
                reconstructions, rff_features = self.model(data)
                scores = pairwise_distances_no_broadcast(
                    data.cpu().numpy(), reconstructions.cpu().numpy())
                outlier_scores[data_idx] = scores
        return outlier_scores
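
Reviewer note: inference ignores the SVM head entirely; the outlier score is the per-sample distance between each (optionally standardized) input and its reconstruction. Assuming pairwise_distances_no_broadcast computes row-wise Euclidean distance, a NumPy equivalent for one batch would be:

    scores = np.sqrt(np.sum((data_np - recon_np) ** 2, axis=1))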
117 changes: 117 additions & 0 deletions pyod/test/test_ae1svm.py
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""Test AE1SVM for outlier detection (pytorch)
"""
from __future__ import division, print_function

import os
import sys
import unittest

from numpy.testing import assert_equal, assert_raises
from sklearn.base import clone
from sklearn.metrics import roc_auc_score

# temporary solution for relative imports in case pyod is not installed
# if pyod is installed, no need to use the following line
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pyod.models.ae1svm import AE1SVM
from pyod.utils.data import generate_data


class TestAE1SVM(unittest.TestCase):

    def setUp(self):
        self.n_train = 6000
        self.n_test = 1000
        self.n_features = 300
        self.contamination = 0.1
        self.roc_floor = 0.8

        self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
            n_train=self.n_train, n_test=self.n_test,
            n_features=self.n_features, contamination=self.contamination,
            random_state=42)

        self.clf = AE1SVM(epochs=5, contamination=self.contamination)
        self.clf.fit(self.X_train)

    def test_parameters(self):
        assert hasattr(self.clf, 'decision_scores_') and \
               self.clf.decision_scores_ is not None
        assert hasattr(self.clf, 'labels_') and self.clf.labels_ is not None
        assert hasattr(self.clf, 'threshold_') and \
               self.clf.threshold_ is not None
        assert hasattr(self.clf, '_mu') and self.clf._mu is not None
        assert hasattr(self.clf, '_sigma') and self.clf._sigma is not None
        assert hasattr(self.clf, 'model') and self.clf.model is not None

    def test_train_scores(self):
        assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])

    def test_prediction_scores(self):
        pred_scores = self.clf.decision_function(self.X_test)

        # check score shapes
        assert_equal(pred_scores.shape[0], self.X_test.shape[0])

        # check performance
        assert roc_auc_score(self.y_test, pred_scores) >= self.roc_floor

    def test_prediction_labels(self):
        pred_labels = self.clf.predict(self.X_test)
        assert_equal(pred_labels.shape, self.y_test.shape)

    def test_prediction_proba(self):
        pred_proba = self.clf.predict_proba(self.X_test)
        assert pred_proba.min() >= 0
        assert pred_proba.max() <= 1

    def test_prediction_proba_linear(self):
        pred_proba = self.clf.predict_proba(self.X_test, method='linear')
        assert pred_proba.min() >= 0
        assert pred_proba.max() <= 1

    def test_prediction_proba_unify(self):
        pred_proba = self.clf.predict_proba(self.X_test, method='unify')
        assert pred_proba.min() >= 0
        assert pred_proba.max() <= 1

    def test_prediction_proba_parameter(self):
        with assert_raises(ValueError):
            self.clf.predict_proba(self.X_test, method='something')

    def test_prediction_labels_confidence(self):
        pred_labels, confidence = self.clf.predict(self.X_test,
                                                   return_confidence=True)
        assert_equal(pred_labels.shape, self.y_test.shape)
        assert_equal(confidence.shape, self.y_test.shape)
        assert confidence.min() >= 0
        assert confidence.max() <= 1

    def test_prediction_proba_linear_confidence(self):
        pred_proba, confidence = self.clf.predict_proba(
            self.X_test, method='linear', return_confidence=True)
        assert pred_proba.min() >= 0
        assert pred_proba.max() <= 1
        assert_equal(confidence.shape, self.y_test.shape)
        assert confidence.min() >= 0
        assert confidence.max() <= 1

    def test_fit_predict(self):
        pred_labels = self.clf.fit_predict(self.X_train)
        assert_equal(pred_labels.shape, self.y_train.shape)

    def test_fit_predict_score(self):
        self.clf.fit_predict_score(self.X_test, self.y_test)
        self.clf.fit_predict_score(self.X_test, self.y_test,
                                   scoring='roc_auc_score')
        self.clf.fit_predict_score(self.X_test, self.y_test,
                                   scoring='prc_n_score')
        with assert_raises(NotImplementedError):
            self.clf.fit_predict_score(self.X_test, self.y_test,
                                       scoring='something')

    def test_model_clone(self):
        # for deep models this may not apply
        clone_clf = clone(self.clf)

    def tearDown(self):
        pass


if __name__ == '__main__':
    unittest.main()