Support One-Versus-One approach (#315)
* Add first version of COneVersusOne

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Fix some bugs

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Add some tests

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Fix Q matrix formula

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Fix comment in tests

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Remove redundant interface

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Add multiclass modes

Signed-off-by: FedyuninV <valery.fedyunin@abbyy.com>

* Rename MM_SingleTree to MM_SingleClassifier

Signed-off-by: FedyuninV <valery.fedyunin@abbyy.com>

* Remove copy-paste from COneVersusOneModel::Classify

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Add OneVsOne approach to Python wrapper

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Add COneVersusOne docs

Signed-off-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>

* Mark findProb function as static

Signed-off-by: Valeriy Fedyunin <stelz40494@gmail.com>

* Fix typo in OneVersusOne.md

Signed-off-by: Valeriy Fedyunin <stelz40494@gmail.com>

* Add options to MulticlassMode param documentation

Signed-off-by: Valeriy Fedyunin <stelz40494@gmail.com>

Co-authored-by: Stanislav Angeliuk <59917951+SAngeliuk@users.noreply.github.com>
Co-authored-by: NeoML-maintainer <65914319+NeoML-maintainer@users.noreply.github.com>
3 people authored Jun 1, 2021
1 parent 76c2e33 commit 853e6f7
Showing 31 changed files with 684 additions and 39 deletions.
10 changes: 8 additions & 2 deletions NeoML/Python/neoml/DecisionTree.py
@@ -70,10 +70,13 @@ class DecisionTreeClassifier(PythonWrapper.DecisionTree):
:param random_selected_feature_count: no more than this number of randomly selected features will be used for each node.
-1 means use all features every time.
:type random_selected_feature_count: int, default=-1
+ :param multiclass_mode: determines how to handle multi-class classification
+ :type multiclass_mode: str, ['single_tree', 'one_vs_all', 'one_vs_one'], default='single_tree'
"""

def __init__(self, criterion='gini', min_subset_size=1, min_subset_part=0.0, min_split_size=1, max_tree_depth=32,
- max_node_count=4096, const_threshold=0.99, random_selected_feature_count=-1):
+ max_node_count=4096, const_threshold=0.99, random_selected_feature_count=-1, multiclass_mode='single_tree'):

if criterion != 'gini' and criterion != 'information_gain':
raise ValueError('The `criterion` must be one of: `gini`, `information_gain`.')
@@ -91,9 +94,12 @@ def __init__(self, criterion='gini', min_subset_size=1, min_subset_part=0.0, min
raise ValueError('The `const_threshold` must be in [0, 1].')
if random_selected_feature_count <= 0 and random_selected_feature_count != -1:
raise ValueError('The `random_selected_feature_count` must be > 0 or -1.')
+ if multiclass_mode != 'single_tree' and multiclass_mode != 'one_vs_all' and multiclass_mode != 'one_vs_one':
+     raise ValueError('The `multiclass_mode` must be one of: `single_tree`, `one_vs_all`, `one_vs_one`.')

super().__init__(int(min_subset_size), float(min_subset_part), int(min_split_size), int(max_tree_depth),
- int(max_node_count), criterion, float(const_threshold), int(random_selected_feature_count))
+ int(max_node_count), criterion, float(const_threshold), int(random_selected_feature_count),
+ multiclass_mode)

def train(self, X, Y, weight=None):
"""Trains the decision tree.
12 changes: 9 additions & 3 deletions NeoML/Python/neoml/Linear.py
@@ -70,10 +70,14 @@ class LinearClassifier(PythonWrapper.Linear) :
:param thread_count: the number of threads to be used while training the model.
:type thread_count: int, default=1
+ :param multiclass_mode: determines how to handle multi-class classification
+ :type multiclass_mode: str, ['one_vs_all', 'one_vs_one'], default='one_vs_all'
"""

def __init__(self, loss='binomial', max_iteration_count=1000, error_weight=1.0,
- sigmoid=(0.0, 0.0), tolerance=-1.0, normalizeError=False, l1_reg=0.0, thread_count=1):
+ sigmoid=(0.0, 0.0), tolerance=-1.0, normalizeError=False, l1_reg=0.0, thread_count=1,
+ multiclass_mode='one_vs_all'):

if loss != 'binomial' and loss != 'squared_hinge' and loss != 'smoothed_hinge':
raise ValueError('The `loss` must be one of: `binomial`, `squared_hinge`, `smoothed_hinge`.')
@@ -84,9 +88,11 @@ def __init__(self, loss='binomial', max_iteration_count=1000, error_weight=1.0,

if thread_count <= 0:
raise ValueError('The `thread_count` must be > 0.')
+ if multiclass_mode != 'one_vs_all' and multiclass_mode != 'one_vs_one':
+     raise ValueError('The `multiclass_mode` must be one of: `one_vs_all`, `one_vs_one`.')

super().__init__(loss, int(max_iteration_count), float(error_weight), float(sigmoid[0]), float(sigmoid[1]), float(tolerance), bool(normalizeError),
- float(l1_reg), int(thread_count))
+ float(l1_reg), int(thread_count), multiclass_mode)

def train(self, X, Y, weight=None):
"""Trains the linear classification model.
@@ -191,7 +197,7 @@ def __init__(self, loss='l2', max_iteration_count=1000, error_weight=1.0,
raise ValueError('The `thread_count` must be > 0.')

super().__init__(loss, int(max_iteration_count), float(error_weight), float(sigmoid[0]), float(sigmoid[1]), float(tolerance), bool(normalizeError),
- float(l1_reg), int(thread_count))
+ float(l1_reg), int(thread_count), '')

def train(self, X, Y, weight=None):
"""Trains the linear regression model.
9 changes: 7 additions & 2 deletions NeoML/Python/neoml/SVM.py
@@ -65,10 +65,13 @@ class SvmClassifier(PythonWrapper.Svm) :
:param thread_count: The number of processing threads to be used while training the model.
:type thread_count: int, default=1
+ :param multiclass_mode: determines how to handle multi-class classification
+ :type multiclass_mode: str, ['one_vs_all', 'one_vs_one'], default='one_vs_all'
"""

def __init__(self, kernel='linear', max_iteration_count=1000, error_weight=1.0,
- degree=1, gamma=1.0, coeff0=1.0, tolerance=0.1, thread_count=1):
+ degree=1, gamma=1.0, coeff0=1.0, tolerance=0.1, thread_count=1, multiclass_mode='one_vs_all'):

if kernel != 'linear' and kernel != 'poly' and kernel != 'rbf' and kernel != 'sigmoid':
raise ValueError('The `kernel` must be one of: `linear`, `poly`, `rbf`, `sigmoid`.')
@@ -79,9 +82,11 @@ def __init__(self, kernel='linear', max_iteration_count=1000, error_weight=1.0,

if thread_count <= 0:
raise ValueError('The `thread_count` must be > 0.')
+ if multiclass_mode != 'one_vs_all' and multiclass_mode != 'one_vs_one':
+     raise ValueError('The `multiclass_mode` must be one of: `one_vs_all`, `one_vs_one`.')

super().__init__(kernel, float(error_weight), int(max_iteration_count), int(degree),
- float(gamma), float(coeff0), float(tolerance), int(thread_count))
+ float(gamma), float(coeff0), float(tolerance), int(thread_count), multiclass_mode)

def train(self, X, Y, weight=None):
"""Trains the SVM classification model.
26 changes: 23 additions & 3 deletions NeoML/Python/src/PyTrainingModel.cpp
@@ -360,7 +360,7 @@ void InitializeTrainingModel(py::module& m)
py::class_<CPyDecisionTree, CPyTrainingModel>(m, "DecisionTree")
.def(
py::init([]( int min_subset_size, float min_subset_part, int min_split_size, int max_tree_depth, int max_node_count, const std::string& criterion,
- float const_threshold, int random_selected_feature_count )
+ float const_threshold, int random_selected_feature_count, const std::string& multiclass_mode )
{
CDecisionTreeTrainingModel::CParams p;
p.SplitCriterion = CDecisionTreeTrainingModel::SC_Count;
@@ -378,6 +378,14 @@ void InitializeTrainingModel(py::module& m)
p.ConstNodeThreshold = const_threshold;
p.RandomSelectedFeaturesCount = random_selected_feature_count;

+ if( multiclass_mode == "single_tree" ) {
+     p.MulticlassMode = MM_SingleClassifier;
+ } else if( multiclass_mode == "one_vs_all" ) {
+     p.MulticlassMode = MM_OneVsAll;
+ } else if( multiclass_mode == "one_vs_one" ) {
+     p.MulticlassMode = MM_OneVsOne;
+ }

return new CPyDecisionTree( p );
})
)
@@ -390,7 +398,7 @@ void InitializeTrainingModel(py::module& m)
py::class_<CPySvm, CPyTrainingModel>(m, "Svm")
.def( py::init(
[]( const std::string& kernel, float error_weight, int max_iteration_count, int degree, float gamma, float coeff0,
- float tolerance, int thread_count ) {
+ float tolerance, int thread_count, const std::string& multiclass_mode ) {
CSvmBinaryClassifierBuilder::CParams p( CSvmKernel::KT_Undefined );
if( kernel == "linear" ) {
p.KernelType = CSvmKernel::KT_Linear;
@@ -409,6 +417,12 @@ void InitializeTrainingModel(py::module& m)
p.Tolerance = tolerance;
p.ThreadCount = thread_count;

+ if( multiclass_mode == "one_vs_all" ) {
+     p.MulticlassMode = MM_OneVsAll;
+ } else if( multiclass_mode == "one_vs_one" ) {
+     p.MulticlassMode = MM_OneVsOne;
+ }

return new CPySvm( p );
})
)
@@ -421,7 +435,7 @@ void InitializeTrainingModel(py::module& m)
py::class_<CPyLinear, CPyTrainingModel>(m, "Linear")
.def( py::init(
[]( const std::string& loss, int max_iteration_count, float error_weight, float sigmoid_a, float sigmoid_b,
- float tolerance, bool normalize_error, float l1_reg, int thread_count ) {
+ float tolerance, bool normalize_error, float l1_reg, int thread_count, const std::string& multiclass_mode ) {
CLinearBinaryClassifierBuilder::CParams p( EF_Count );
if( loss == "smoothed_hinge" ) {
p.Function = EF_SmoothedHinge;
@@ -441,6 +455,12 @@ void InitializeTrainingModel(py::module& m)
p.L1Coeff = l1_reg;
p.ThreadCount = thread_count;

+ if( multiclass_mode == "one_vs_all" ) {
+     p.MulticlassMode = MM_OneVsAll;
+ } else if( multiclass_mode == "one_vs_one" ) {
+     p.MulticlassMode = MM_OneVsOne;
+ }

return new CPyLinear( p );
})
)
22 changes: 17 additions & 5 deletions NeoML/Python/tests.py
@@ -2133,10 +2133,14 @@ def _test_classification_model(self, model, params, is_binary=False):
X_sparse = sparse.csr_matrix(X_dense)
val = 1 if is_binary else 3
y = val * np.ones(20, dtype=np.int32)
+ if not is_binary: # every class should be represented in dataset
+     for i in range(3):
+         y[i] = i
weight = np.ones(20, dtype=np.float32)
for X in (X_dense, X_dense_list, X_sparse):
classifier = model(**params).train(X, y, weight)
- pred = classifier.classify(X[0:3])
+ pred = classifier.classify(X[-3:])
+ print(pred, np.argmax(pred))
self.assertTrue(np.equal(np.argmax(pred), [val, val, val]).all())

def _test_regression_model(self, model, params):
@@ -2155,7 +2159,8 @@ def test_gradient_boosting_classification(self):
('binomial', 'exponential', 'squared_hinge', 'l2'),
('full', 'hist', 'multi_full'), (1, 4), (False, True)):
self._test_classification_model(neoml.GradientBoost.GradientBoostClassifier,
- dict(loss=loss, iteration_count=10, builder_type=builder_type, thread_count=thread_count))
+ dict(loss=loss, iteration_count=10, builder_type=builder_type, thread_count=thread_count),
+ is_binary=is_binary)

def test_gradient_boosting_regression(self):
for builder_type, thread_count in itertools.product(('full', 'hist'), (1, 4)):
@@ -2165,19 +2170,26 @@ def test_decision_tree_classification(self):
def test_decision_tree_classification(self):
for criterion, is_binary in itertools.product(('gini', 'information_gain'), (False, True)):
self._test_classification_model(neoml.DecisionTree.DecisionTreeClassifier,
- dict(criterion=criterion))
+ dict(criterion=criterion), is_binary=is_binary)
+ for multiclass_mode in ('single_tree', 'one_vs_all', 'one_vs_one'):
+     self._test_classification_model(neoml.DecisionTree.DecisionTreeClassifier, dict(multiclass_mode=multiclass_mode))

def test_svm_classification(self):
for kernel, thread_count, is_binary in itertools.product(('linear', 'poly', 'rbf', 'sigmoid'),
(1, 4), (False, True)):
self._test_classification_model(neoml.SVM.SvmClassifier,
- dict(kernel=kernel, thread_count=thread_count))
+ dict(kernel=kernel, thread_count=thread_count), is_binary=is_binary)
+ for multiclass_mode in ('one_vs_all', 'one_vs_one'):
+     print('svm ', multiclass_mode)
+     self._test_classification_model(neoml.SVM.SvmClassifier, dict(multiclass_mode=multiclass_mode))

def test_linear_classification(self):
for loss, thread_count, is_binary in itertools.product(('binomial', 'squared_hinge', 'smoothed_hinge'),
(1, 4), (False, True)):
self._test_classification_model(neoml.Linear.LinearClassifier,
- dict(loss=loss, thread_count=thread_count))
+ dict(loss=loss, thread_count=thread_count), is_binary=is_binary)
+ for multiclass_mode in ('one_vs_all', 'one_vs_one'):
+     self._test_classification_model(neoml.Linear.LinearClassifier, dict(multiclass_mode=multiclass_mode))

def test_linear_regression(self):
for thread_count in (1, 4):
1 change: 1 addition & 0 deletions NeoML/docs/en/API/ClassificationAndRegression/DecisionTree.md
@@ -27,6 +27,7 @@ The parameters are represented by a `CDecisionTree::CParams` structure.
- *SplitCriterion* — the criterion for subset splitting.
- *ConstNodeThreshold* — the ratio of the equal elements in the subset which should be the threshold for creating a constant node (may be from 0 to 1).
- *RandomSelectedFeaturesCount* — no more than this number of randomly selected features will be used for each node. Set the value to `-1` to use all features every time.
+ - *MulticlassMode* — the approach used for multiclass classification: SingleClassifier (default), OneVsAll, or OneVsOne.

## Model

1 change: 1 addition & 0 deletions NeoML/docs/en/API/ClassificationAndRegression/Linear.md
@@ -28,6 +28,7 @@ The parameters are represented by a `CLinear::CParams` structure.
- *NormalizeError* — specifies if the error should be normalized.
- *L1Coeff* — the L1 regularization coefficient; set to `0` to use the L2 regularization instead.
- *ThreadCount* — the number of processing threads to be used while training the model.
+ - *MulticlassMode* — the approach used for multiclass classification: OneVsAll (default) or OneVsOne.

### Loss function

42 changes: 42 additions & 0 deletions NeoML/docs/en/API/ClassificationAndRegression/OneVersusOne.md
@@ -0,0 +1,42 @@
# One Versus One Classification COneVersusOne

<!-- TOC -->

- [One Versus One Classification COneVersusOne](#one-versus-one-classification-coneversusone)
- [Training settings](#training-settings)
- [Model](#model)
- [Classification result](#classification-result)
- [Sample](#sample)

<!-- /TOC -->

The one-versus-one method provides a way to solve a multi-class classification problem using only binary classifiers.

The original classification problem is represented as a series of binary classification problems, one for each pair of classes, each estimating the pairwise probability that an object belongs to one class of the pair rather than the other.

Afterwards, the optimal probabilities for each class are found by solving the optimization problem described in Section 4 of [this article](https://www.csie.ntu.edu.tw/~cjlin/papers/svmprob/svmprob.pdf).
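
As a sketch of that optimization (following the second method from Section 4 of the cited paper; the exact variant used in the implementation may differ), let `r_ij` denote the estimated pairwise probability of class `i` in the comparison against class `j`. The class probabilities `p` are then found by solving:

```latex
\min_{p}\ \frac{1}{2}\sum_{i=1}^{k}\sum_{j \neq i}\left(r_{ji}\,p_i - r_{ij}\,p_j\right)^2
\quad \text{subject to} \quad \sum_{i=1}^{k} p_i = 1,\quad p_i \ge 0 .
```

Up to a constant factor this is the quadratic program `min_p p^T Q p` with `Q_ii = sum_{j != i} r_ji^2` and `Q_ij = -r_ji r_ij` — presumably the Q matrix referred to in the commit history above.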

In the **NeoML** library this method is implemented by the `COneVersusOne` class. It exposes a `Train` method for creating a classification model.

## Training settings

The only parameter the algorithm requires is a pointer to the underlying binary classification method, represented by an object that implements the [ITrainingModel](TrainingModels.md) interface.

## Model

The trained model is an ensemble of binary classification models. It implements the [`IModel` interface](Models.md#for-classification).

## Classification result

The model provides the standard `Classify` method, which writes the result into the given [`CClassificationResult`](README.md#classification-result).

## Sample

Here is a simple example of training a one-versus-one model using a linear binary classifier.

```c++
CLinear linear( EF_LogReg );

COneVersusOne oneVersusOne( linear );
CPtr<IModel> model = oneVersusOne.Train( *trainData );
```
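
The same approach is available from the Python wrapper added in this commit via the `multiclass_mode` parameter. Below is a minimal usage sketch mirroring the tests above; the exact shape of the `classify` output is an assumption here:

```python
import numpy as np
import neoml

# Toy data: 20 samples, 5 features, 3 classes (every class must appear in the dataset).
X = np.random.rand(20, 5).astype(np.float32)
y = np.arange(20, dtype=np.int32) % 3

# 'one_vs_one' trains one binary classifier per pair of classes.
classifier = neoml.Linear.LinearClassifier(loss='binomial', multiclass_mode='one_vs_one')
model = classifier.train(X, y)

# classify() returns per-class probabilities; argmax gives the predicted labels.
probs = model.classify(X[-3:])
print(np.argmax(probs, axis=1))
```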
6 changes: 6 additions & 0 deletions NeoML/docs/en/API/ClassificationAndRegression/README.md
@@ -51,6 +51,12 @@ This method helps solve a multi-class classification problem using only binary c

It is implemented by the [COneVersusAll](OneVersusAll.md) class. The trained multi-class classification model implements the `IOneVersusAllModel` interface.

+ ### One versus one method
+
+ This method helps solve a multi-class classification problem using only binary classifiers.
+
+ It is implemented by the [COneVersusOne](OneVersusOne.md) class. The trained model implements the `IModel` interface.

## Auxiliary interfaces

All the methods for model training implement common interfaces, accept the input data of the same type and train models that may be accessed using the common interface.
3 changes: 2 additions & 1 deletion NeoML/docs/en/API/ClassificationAndRegression/Svm.md
@@ -21,10 +21,11 @@ The parameters are represented by a `CSvm::CParams` structure.
- *Coeff0* — the kernel free term (for `KT_Poly`, `KT_Sigmoid`)
- *Tolerance* — the algorithm precision, the stop criterion
- *ThreadCount* — the number of processing threads to be used while training
+ - *MulticlassMode* — the approach used for multiclass classification: OneVsAll (default) or OneVsOne

## Model

- The trained model implements the [`ILinearBinaryModel`](Linear.md#for-classification) interface if a `KT_Linear` kernel is used; the [`IOneVersusAllModel`](OneVersusAll.md#model) if number of classes > 2; otherwise, it implements the `ISvmBinaryModel` interface.
+ The trained model implements the [`ILinearBinaryModel`](Linear.md#for-classification) interface if the `KT_Linear` kernel is used; the model determined by *MulticlassMode* if the number of classes > 2; otherwise, the `ISvmBinaryModel` interface.

```c++
// SVM binary classifier interface
4 changes: 2 additions & 2 deletions NeoML/docs/en/Tutorial/News20Classification.md
@@ -11,7 +11,7 @@

This tutorial walks through training a **NeoML** classification model to classify the well-known [News20](https://archive.ics.uci.edu/ml/datasets/Twenty+Newsgroups) data set.

- We are going to use the [linear classifier](../API/ClassificationAndRegression/Linear.md) that implicitly uses ["one versus all"](../API/ClassificationAndRegression/OneVersusAll.md) method.
+ We are going to use the [linear classifier](../API/ClassificationAndRegression/Linear.md), which by default uses the ["one versus all"](../API/ClassificationAndRegression/OneVersusAll.md) method for the multiclass task.

## Preparing the input data

@@ -36,7 +36,7 @@ testArchive >> testData;
The "one versus all" method uses the specified classifier to train a model per each class that would determine the probability for an object to belong to this class. An input object is then classified by the models voting.
- 1. Create a linear classifier using the `CLinear` class (`COneVersusAll` will take place implicitly). Select the logistic regression loss function (`EF_LogReg` constant).
+ 1. Create a linear classifier using the `CLinear` class (by default, `COneVersusAll` will be used for the multiclass task). Select the logistic regression loss function (`EF_LogReg` constant).
2. Call the `Train` method, passing the `trainData` training set prepared above. The method will train the model and return it as an object implementing the `IModel` interface.
```c++
3 changes: 2 additions & 1 deletion NeoML/docs/ru/API/ClassificationAndRegression/DecisionTree.md
@@ -26,7 +26,8 @@
- *MaxNodesCount* — the maximum number of nodes in the tree;
- *SplitCriterion* — the criterion for splitting subsets when building the tree;
- *ConstNodeThreshold* — the ratio of equal elements in a subset above which a constant node is created (may be from 0 to 1);
- - *RandomSelectedFeaturesCount* — no more than this number of randomly selected features is used when building each node. Set the value to `-1` to use all features.
+ - *RandomSelectedFeaturesCount* — no more than this number of randomly selected features is used when building each node. Set the value to `-1` to use all features;
+ - *MulticlassMode* — the approach used for multiclass classification: SingleClassifier (default), OneVsAll, or OneVsOne.

## Model
