Multi-Class Classification Sample

This tutorial walks through training a NeoML classification model to classify the well-known News20 data set.

We are going to combine a linear binary classifier with the "one versus all" method.

Preparing the input data

We assume that the data set is split into two parts, a training set and a test set, and that each part is serialized in a file on disk as a CMemoryProblem (a simple implementation of the IProblem interface provided in the library).

The library serialization methods can be used to load the data into memory for processing.

CPtr<CMemoryProblem> trainData = new CMemoryProblem();
CPtr<CMemoryProblem> testData = new CMemoryProblem();

CArchiveFile trainFile( "news20.train", CArchive::load );
CArchive trainArchive( &trainFile, CArchive::load );
trainArchive >> trainData;

CArchiveFile testFile( "news20.test", CArchive::load );
CArchive testArchive( &testFile, CArchive::load );
testArchive >> testData;

Training the classifier

The "one versus all" classifier uses the specified binary classifier to train one model for each class; each model estimates the probability that an object belongs to its class. An input object is then classified by a vote among these models.
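The voting idea can be illustrated with a minimal standalone sketch. It assumes the simplest possible rule, where the class whose binary model reports the highest probability wins; the way COneVersusAll actually combines or normalizes the per-class outputs is not described here and may differ. PickPreferredClass is a hypothetical helper and is not part of NeoML.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a NeoML function): chooses the winning class
// from the outputs of the per-class binary models.
// probabilities[k] is the output of the model "class k versus the rest".
// Assumption: the highest probability wins; ties go to the lowest index.
int PickPreferredClass( const std::vector<double>& probabilities )
{
	assert( !probabilities.empty() );

	std::size_t best = 0;
	for( std::size_t k = 1; k < probabilities.size(); k++ ) {
		if( probabilities[k] > probabilities[best] ) {
			best = k;
		}
	}
	return static_cast<int>( best );
}
```

For example, calling PickPreferredClass with the values 0.1, 0.7, 0.3 selects class 1, because the second model is the most confident one.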

  1. Create a linear binary classifier using the CLinearBinaryClassifierBuilder class. Select the logistic regression loss function (the EF_LogReg constant).
  2. Create a COneVersusAll classifier, passing the binary classifier set up in the previous step to the constructor.
  3. Call the Train method, passing the trainData training set prepared above. The method trains the model and returns it as an object implementing the IModel interface.

CLinearBinaryClassifierBuilder linear( EF_LogReg );
COneVersusAll oneVersusAll( linear );
CPtr<IModel> model = oneVersusAll.Train( *trainData );
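To give some intuition for the logistic regression loss selected in step 1, here is a minimal standalone sketch of the textbook per-example logistic loss, log(1 + exp(-y * (w · x))), for a label y of +1 or -1. It is an illustration only: LogisticLoss is a hypothetical helper, and NeoML's internal formulation (regularization, weighting, numerical safeguards) is not shown and may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a NeoML function): textbook logistic loss
// for a single example with a linear model.
// label must be +1 (object belongs to the class) or -1 (it does not).
double LogisticLoss( const std::vector<double>& weights,
	const std::vector<double>& features, int label )
{
	assert( weights.size() == features.size() );
	assert( label == 1 || label == -1 );

	// Linear score: dot product of the weights and the feature vector
	double score = 0.0;
	for( std::size_t i = 0; i < weights.size(); i++ ) {
		score += weights[i] * features[i];
	}

	// log(1 + exp(-label * score)); small when the sign of the score
	// agrees with the label, large when it disagrees
	return std::log1p( std::exp( -label * score ) );
}
```

With all weights equal to zero the loss is log(2) for any example; training lowers it by moving the score toward the correct sign for each training vector.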

Analyzing the results

We can evaluate the trained model on the test set using the Classify method of the IModel interface. Call this method for each vector of the testData data set prepared earlier and count how often the predicted class matches the true one.

int correct = 0;
for( int i = 0; i < testData->GetVectorCount(); i++ ) {
	CClassificationResult result;
	model->Classify( testData->GetVector( i ), result );

	// Count the vector as correct if the predicted class matches its label
	if( result.PreferredClass == testData->GetClass( i ) ) {
		correct++;
	}
}

// Accuracy: the share of correctly classified test vectors
double totalResult = static_cast<double>(correct) / testData->GetVectorCount();
printf("%.3f\n", totalResult);

In this test run, 83.3% of the vectors were classified correctly.
