diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 98864fd0ec..4254e6fb21 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -52,6 +52,8 @@ tree which is trained for both regression and classification. (See {ml-pull}811[ (See {ml-pull}818[#818].) * Reduce memory usage of {ml} native processes on Windows. (See {ml-pull}844[#844].) * Reduce runtime of classification and regression. (See {ml-pull}863[#863].) +* Emit `prediction_field_name` in ml results using the type of a `dependent_variable`. +(See {ml-pull}877[#877].) === Bug Fixes * Fixes potential memory corruption when determining seasonality. (See {ml-pull}852[#852].) diff --git a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h index deed15b7cf..3f0964af98 100644 --- a/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h +++ b/include/api/CDataFrameTrainBoostedTreeClassifierRunner.h @@ -44,6 +44,10 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final const TRowRef& row, core::CRapidJsonConcurrentLineWriter& writer) const; + //! Write the predicted category value as string, int or bool. + void writePredictedCategoryValue(const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const; + //! \return A serialisable definition of the trained classification model. TInferenceModelDefinitionUPtr inferenceModelDefinition(const TStrVec& fieldNames, @@ -55,6 +59,7 @@ class API_EXPORT CDataFrameTrainBoostedTreeClassifierRunner final private: std::size_t m_NumTopClasses; + std::string m_DependentVariableType; }; //! \brief Makes a core::CDataFrame boosted tree classification runner. diff --git a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc index f4609d595b..6541451ca1 100644 --- a/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc +++ b/lib/api/CDataFrameTrainBoostedTreeClassifierRunner.cc @@ -32,6 +32,7 @@ using TSizeVec = std::vector; // Configuration const std::string NUM_TOP_CLASSES{"num_top_classes"}; +const std::string DEPENDENT_VARIABLE_TYPE{"dependent_variable_type"}; const std::string BALANCED_CLASS_LOSS{"balanced_class_loss"}; // Output @@ -47,6 +48,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::parameterReader() { static const CDataFrameAnalysisConfigReader PARAMETER_READER{[] { auto theReader = CDataFrameTrainBoostedTreeRunner::parameterReader(); theReader.addParameter(NUM_TOP_CLASSES, CDataFrameAnalysisConfigReader::E_OptionalParameter); + theReader.addParameter(DEPENDENT_VARIABLE_TYPE, + CDataFrameAnalysisConfigReader::E_OptionalParameter); theReader.addParameter(BALANCED_CLASS_LOSS, CDataFrameAnalysisConfigReader::E_OptionalParameter); return theReader; @@ -60,6 +63,8 @@ CDataFrameTrainBoostedTreeClassifierRunner::CDataFrameTrainBoostedTreeClassifier : CDataFrameTrainBoostedTreeRunner{spec, parameters} { m_NumTopClasses = parameters[NUM_TOP_CLASSES].fallback(std::size_t{0}); + m_DependentVariableType = + parameters[DEPENDENT_VARIABLE_TYPE].fallback(std::string("string")); this->boostedTreeFactory().balanceClassTrainingLoss( parameters[BALANCED_CLASS_LOSS].fallback(true)); @@ -119,7 +124,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( writer.StartObject(); writer.Key(this->predictionFieldName()); - writer.String(categoryValues[predictedCategoryId]); + writePredictedCategoryValue(categoryValues[predictedCategoryId], writer); writer.Key(PREDICTION_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[predictedCategoryId]); writer.Key(IS_TRAINING_FIELD_NAME); @@ -135,7 +140,7 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( for (std::size_t i = 0; i < std::min(categoryIds.size(), m_NumTopClasses); ++i) { writer.StartObject(); writer.Key(CLASS_NAME_FIELD_NAME); - writer.String(categoryValues[categoryIds[i]]); + writePredictedCategoryValue(categoryValues[categoryIds[i]], writer); writer.Key(CLASS_PROBABILITY_FIELD_NAME); writer.Double(probabilityOfCategory[i]); writer.EndObject(); @@ -158,6 +163,19 @@ void CDataFrameTrainBoostedTreeClassifierRunner::writeOneRow( columnHoldingPrediction, row, writer); } +void CDataFrameTrainBoostedTreeClassifierRunner::writePredictedCategoryValue( + const std::string& categoryValue, + core::CRapidJsonConcurrentLineWriter& writer) const { + + if (m_DependentVariableType == "int") { + writer.Int(std::stoi(categoryValue)); + } else if (m_DependentVariableType == "bool") { + writer.Bool(std::stoi(categoryValue) == 1); + } else { + writer.String(categoryValue); + } +} + CDataFrameTrainBoostedTreeClassifierRunner::TLossFunctionUPtr CDataFrameTrainBoostedTreeClassifierRunner::chooseLossFunction(const core::CDataFrame& frame, std::size_t dependentVariableColumn) const { diff --git a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc index dc60bf8bbf..241bba3bc5 100644 --- a/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc +++ b/lib/api/unittest/CDataFrameTrainBoostedTreeClassifierRunnerTest.cc @@ -45,15 +45,20 @@ BOOST_AUTO_TEST_CASE(testPredictionFieldNameClash) { BOOST_TEST_REQUIRE(errors[0] == "Input error: prediction_field_name must not be equal to any of [is_training, prediction_probability, top_classes]."); } -BOOST_AUTO_TEST_CASE(testWriteOneRow) { +template +void testWriteOneRow(const std::string& dependentVariableField, + const std::string& dependentVariableType, + T (rapidjson::Value::*extract)() const, + const std::vector& expectedPredictions) { // Prepare input data frame - const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", "x5_prediction"}; - const TStrVec categoricalColumns{"x1", "x2", "x5"}; + const std::string predictionField = dependentVariableField + "_prediction"; + const TStrVec columnNames{"x1", "x2", "x3", "x4", "x5", predictionField}; + const TStrVec categoricalColumns{"x1", "x2", "x3", "x4", "x5"}; const TStrVecVec rows{{"a", "b", "1.0", "1.0", "cat", "-1.0"}, - {"a", "b", "2.0", "2.0", "cat", "-0.5"}, - {"a", "b", "5.0", "5.0", "dog", "-0.1"}, - {"c", "d", "5.0", "5.0", "dog", "1.0"}, - {"e", "f", "5.0", "5.0", "dog", "1.5"}}; + {"a", "b", "1.0", "1.0", "cat", "-0.5"}, + {"a", "b", "5.0", "0.0", "dog", "-0.1"}, + {"c", "d", "5.0", "0.0", "dog", "1.0"}, + {"e", "f", "5.0", "0.0", "dog", "1.5"}}; std::unique_ptr frame = core::makeMainStorageDataFrame(columnNames.size()).first; frame->columnNames(columnNames); @@ -67,10 +72,13 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { // Create classification analysis runner object const auto spec{test::CDataFrameAnalysisSpecificationFactory::predictionSpec( - "classification", "x5", rows.size(), columnNames.size(), 13000000, 0, 0, - categoricalColumns)}; + "classification", dependentVariableField, rows.size(), + columnNames.size(), 13000000, 0, 0, categoricalColumns)}; rapidjson::Document jsonParameters; - jsonParameters.Parse("{\"dependent_variable\": \"x5\"}"); + jsonParameters.Parse("{" + " \"dependent_variable\": \"" + dependentVariableField + "\"," + " \"dependent_variable_type\": \"" + dependentVariableType + "\"" + "}"); const auto parameters{ api::CDataFrameTrainBoostedTreeClassifierRunner::parameterReader().read(jsonParameters)}; api::CDataFrameTrainBoostedTreeClassifierRunner runner(*spec, parameters); @@ -83,10 +91,10 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) { const auto columnHoldingDependentVariable{ - std::find(columnNames.begin(), columnNames.end(), "x5") - + std::find(columnNames.begin(), columnNames.end(), dependentVariableField) - columnNames.begin()}; const auto columnHoldingPrediction{ - std::find(columnNames.begin(), columnNames.end(), "x5_prediction") - + std::find(columnNames.begin(), columnNames.end(), predictionField) - columnNames.begin()}; for (auto row = beginRows; row != endRows; ++row) { runner.writeOneRow(*frame, columnHoldingDependentVariable, @@ -95,17 +103,17 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { }); } // Verify results - const TStrVec expectedPredictions{"cat", "cat", "cat", "dog", "dog"}; rapidjson::Document arrayDoc; arrayDoc.Parse(output.str().c_str()); BOOST_TEST_REQUIRE(arrayDoc.IsArray()); BOOST_TEST_REQUIRE(arrayDoc.Size() == rows.size()); + BOOST_TEST_REQUIRE(arrayDoc.Size() == expectedPredictions.size()); for (std::size_t i = 0; i < arrayDoc.Size(); ++i) { BOOST_TEST_CONTEXT("Result for row " << i) { const rapidjson::Value& object = arrayDoc[rapidjson::SizeType(i)]; BOOST_TEST_REQUIRE(object.IsObject()); - BOOST_TEST_REQUIRE(object.HasMember("x5_prediction")); - BOOST_TEST_REQUIRE(object["x5_prediction"].GetString() == + BOOST_TEST_REQUIRE(object.HasMember(predictionField)); + BOOST_TEST_REQUIRE((object[predictionField].*extract)() == expectedPredictions[i]); BOOST_TEST_REQUIRE(object.HasMember("prediction_probability")); BOOST_TEST_REQUIRE(object["prediction_probability"].GetDouble() > 0.5); @@ -115,4 +123,23 @@ BOOST_AUTO_TEST_CASE(testWriteOneRow) { } } +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsInt) { + testWriteOneRow("x3", "int", &rapidjson::Value::GetInt, {1, 1, 1, 5, 5}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsBool) { + testWriteOneRow("x4", "bool", &rapidjson::Value::GetBool, + {true, true, true, false, false}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableIsString) { + testWriteOneRow("x5", "string", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + +BOOST_AUTO_TEST_CASE(testWriteOneRow_DependentVariableTypeMissing) { + testWriteOneRow("x5", "", &rapidjson::Value::GetString, + {"cat", "cat", "cat", "dog", "dog"}); +} + BOOST_AUTO_TEST_SUITE_END()