Skip to content

Commit

Permalink
Added predictor class for AssistRanker NN classifier.
Browse files Browse the repository at this point in the history
Added a predictor class for the new NN classifier in AssistRanker. This
makes it possible to use downloadable unquantized NN classifiers for
inferencing.

Bug: 907727
Change-Id: Ib1e5647e3792f10757c17968a6b0b9b198a8d620
Reviewed-on: https://chromium-review.googlesource.com/c/1399768
Commit-Queue: Jon Napper <napper@chromium.org>
Reviewed-by: Charles . <charleszhao@chromium.org>
Cr-Commit-Position: refs/heads/master@{#625104}
  • Loading branch information
Jon Napper authored and Commit Bot committed Jan 23, 2019
1 parent 1d49df6 commit 63366d9
Show file tree
Hide file tree
Showing 9 changed files with 440 additions and 11 deletions.
3 changes: 3 additions & 0 deletions components/assist_ranker/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ static_library("assist_ranker") {
"base_predictor.h",
"binary_classifier_predictor.cc",
"binary_classifier_predictor.h",
"classifier_predictor.cc",
"classifier_predictor.h",
"example_preprocessing.cc",
"example_preprocessing.h",
"fake_ranker_model_loader.cc",
Expand Down Expand Up @@ -56,6 +58,7 @@ source_set("unit_tests") {
sources = [
"base_predictor_unittest.cc",
"binary_classifier_predictor_unittest.cc",
"classifier_predictor_unittest.cc",
"example_preprocessing_unittest.cc",
"generic_logistic_regression_inference_unittest.cc",
"nn_classifier_test_util.cc",
Expand Down
114 changes: 114 additions & 0 deletions components/assist_ranker/classifier_predictor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/assist_ranker/classifier_predictor.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/files/file_path.h"
#include "components/assist_ranker/example_preprocessing.h"
#include "components/assist_ranker/nn_classifier.h"
#include "components/assist_ranker/proto/ranker_model.pb.h"
#include "components/assist_ranker/ranker_model.h"
#include "components/assist_ranker/ranker_model_loader_impl.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"

namespace assist_ranker {

ClassifierPredictor::ClassifierPredictor(const PredictorConfig& config)
: BasePredictor(config){};
ClassifierPredictor::~ClassifierPredictor(){};

// static
std::unique_ptr<ClassifierPredictor> ClassifierPredictor::Create(
const PredictorConfig& config,
const base::FilePath& model_path,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
std::unique_ptr<ClassifierPredictor> predictor(
new ClassifierPredictor(config));
if (!predictor->is_query_enabled()) {
DVLOG(1) << "Query disabled, bypassing model loading.";
return predictor;
}
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
auto model_loader = std::make_unique<RankerModelLoaderImpl>(
base::BindRepeating(&ClassifierPredictor::ValidateModel),
base::BindRepeating(&ClassifierPredictor::OnModelAvailable,
base::Unretained(predictor.get())),
url_loader_factory, model_path, model_url, config.uma_prefix);
predictor->LoadModel(std::move(model_loader));
return predictor;
}

bool ClassifierPredictor::Predict(const std::vector<float>& features,
std::vector<float>* prediction) {
if (!IsReady()) {
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false;
}

*prediction = nn_classifier::Inference(model_, features);
return true;
}

bool ClassifierPredictor::Predict(RankerExample example,
std::vector<float>* prediction) {
if (!IsReady()) {
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
return false;
}

if (!model_.has_preprocessor_config()) {
DVLOG(1) << "No preprocessor config specified.";
return false;
}

const int preprocessor_error =
ExamplePreprocessor::Process(model_.preprocessor_config(), &example);

// It is okay to ignore cases where there is an extra feature that is not in
// the config.
if (preprocessor_error != ExamplePreprocessor::kSuccess &&
preprocessor_error != ExamplePreprocessor::kNoFeatureIndexFound) {
DVLOG(1) << "Preprocessing error " << preprocessor_error;
return false;
}

const auto& vec =
example.features()
.at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
.float_list()
.float_value();
const std::vector<float> features(vec.begin(), vec.end());
return Predict(features, prediction);
}

// static
RankerModelStatus ClassifierPredictor::ValidateModel(const RankerModel& model) {
if (model.proto().model_case() != RankerModelProto::kNnClassifier) {
DVLOG(0) << "Model is incompatible.";
return RankerModelStatus::INCOMPATIBLE;
}
return nn_classifier::Validate(model.proto().nn_classifier())
? RankerModelStatus::OK
: RankerModelStatus::INCOMPATIBLE;
}

bool ClassifierPredictor::Initialize() {
if (ranker_model_->proto().model_case() == RankerModelProto::kNnClassifier) {
model_ = ranker_model_->proto().nn_classifier();
return true;
}

DVLOG(0) << "Could not initialize inference module.";
return false;
}

} // namespace assist_ranker
68 changes: 68 additions & 0 deletions components/assist_ranker/classifier_predictor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_
#define COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_

#include <memory>
#include <vector>

#include "base/compiler_specific.h"
#include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/nn_classifier.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"

namespace base {
class FilePath;
}

namespace network {
class SharedURLLoaderFactory;
}

namespace assist_ranker {

// Predictor class for single-layer neural network models.
class ClassifierPredictor : public BasePredictor {
public:
~ClassifierPredictor() override;

// Returns an new predictor instance with the given |config| and initialize
// its model loader. The |request_context getter| is passed to the
// predictor's model_loader which holds it as scoped_refptr.
static std::unique_ptr<ClassifierPredictor> Create(
const PredictorConfig& config,
const base::FilePath& model_path,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
WARN_UNUSED_RESULT;

// Performs inferencing on the specified RankerExample. The example is first
// preprocessed using the model config. Returns false if a prediction could
// not be made (e.g. the model is not loaded yet).
bool Predict(RankerExample example,
std::vector<float>* prediction) WARN_UNUSED_RESULT;

// Performs inferencing on the specified feature vector. Returns false if
// a prediction could not be made.
bool Predict(const std::vector<float>& features,
std::vector<float>* prediction) WARN_UNUSED_RESULT;

// Validates that the loaded RankerModel is a valid BinaryClassifier model.
static RankerModelStatus ValidateModel(const RankerModel& model);

protected:
// Instantiates the inference module.
bool Initialize() override;

private:
friend class ClassifierPredictorTest;
ClassifierPredictor(const PredictorConfig& config);

NNClassifierModel model_;

DISALLOW_COPY_AND_ASSIGN(ClassifierPredictor);
};

} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_
Loading

0 comments on commit 63366d9

Please sign in to comment.