forked from chromium/chromium
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added predictor class for AssistRanker NN classifier.
Added a predictor class for the new NN classifier in AssistRanker. This makes it possible to use downloadable unquantized NN classifiers for inferencing. Bug: 907727 Change-Id: Ib1e5647e3792f10757c17968a6b0b9b198a8d620 Reviewed-on: https://chromium-review.googlesource.com/c/1399768 Commit-Queue: Jon Napper <napper@chromium.org> Reviewed-by: Charles . <charleszhao@chromium.org> Cr-Commit-Position: refs/heads/master@{#625104}
- Loading branch information
Jon Napper
authored and
Commit Bot
committed
Jan 23, 2019
1 parent
1d49df6
commit 63366d9
Showing
9 changed files
with
440 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Copyright 2018 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#include "components/assist_ranker/classifier_predictor.h" | ||
|
||
#include <memory> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "base/bind.h" | ||
#include "base/bind_helpers.h" | ||
#include "base/files/file_path.h" | ||
#include "components/assist_ranker/example_preprocessing.h" | ||
#include "components/assist_ranker/nn_classifier.h" | ||
#include "components/assist_ranker/proto/ranker_model.pb.h" | ||
#include "components/assist_ranker/ranker_model.h" | ||
#include "components/assist_ranker/ranker_model_loader_impl.h" | ||
#include "services/network/public/cpp/shared_url_loader_factory.h" | ||
|
||
namespace assist_ranker { | ||
|
||
ClassifierPredictor::ClassifierPredictor(const PredictorConfig& config) | ||
: BasePredictor(config){}; | ||
ClassifierPredictor::~ClassifierPredictor(){}; | ||
|
||
// static | ||
std::unique_ptr<ClassifierPredictor> ClassifierPredictor::Create( | ||
const PredictorConfig& config, | ||
const base::FilePath& model_path, | ||
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) { | ||
std::unique_ptr<ClassifierPredictor> predictor( | ||
new ClassifierPredictor(config)); | ||
if (!predictor->is_query_enabled()) { | ||
DVLOG(1) << "Query disabled, bypassing model loading."; | ||
return predictor; | ||
} | ||
const GURL& model_url = predictor->GetModelUrl(); | ||
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName(); | ||
DVLOG(1) << "Model URL: " << model_url; | ||
auto model_loader = std::make_unique<RankerModelLoaderImpl>( | ||
base::BindRepeating(&ClassifierPredictor::ValidateModel), | ||
base::BindRepeating(&ClassifierPredictor::OnModelAvailable, | ||
base::Unretained(predictor.get())), | ||
url_loader_factory, model_path, model_url, config.uma_prefix); | ||
predictor->LoadModel(std::move(model_loader)); | ||
return predictor; | ||
} | ||
|
||
bool ClassifierPredictor::Predict(const std::vector<float>& features, | ||
std::vector<float>* prediction) { | ||
if (!IsReady()) { | ||
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction."; | ||
return false; | ||
} | ||
|
||
*prediction = nn_classifier::Inference(model_, features); | ||
return true; | ||
} | ||
|
||
bool ClassifierPredictor::Predict(RankerExample example, | ||
std::vector<float>* prediction) { | ||
if (!IsReady()) { | ||
DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction."; | ||
return false; | ||
} | ||
|
||
if (!model_.has_preprocessor_config()) { | ||
DVLOG(1) << "No preprocessor config specified."; | ||
return false; | ||
} | ||
|
||
const int preprocessor_error = | ||
ExamplePreprocessor::Process(model_.preprocessor_config(), &example); | ||
|
||
// It is okay to ignore cases where there is an extra feature that is not in | ||
// the config. | ||
if (preprocessor_error != ExamplePreprocessor::kSuccess && | ||
preprocessor_error != ExamplePreprocessor::kNoFeatureIndexFound) { | ||
DVLOG(1) << "Preprocessing error " << preprocessor_error; | ||
return false; | ||
} | ||
|
||
const auto& vec = | ||
example.features() | ||
.at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName) | ||
.float_list() | ||
.float_value(); | ||
const std::vector<float> features(vec.begin(), vec.end()); | ||
return Predict(features, prediction); | ||
} | ||
|
||
// static | ||
RankerModelStatus ClassifierPredictor::ValidateModel(const RankerModel& model) { | ||
if (model.proto().model_case() != RankerModelProto::kNnClassifier) { | ||
DVLOG(0) << "Model is incompatible."; | ||
return RankerModelStatus::INCOMPATIBLE; | ||
} | ||
return nn_classifier::Validate(model.proto().nn_classifier()) | ||
? RankerModelStatus::OK | ||
: RankerModelStatus::INCOMPATIBLE; | ||
} | ||
|
||
bool ClassifierPredictor::Initialize() { | ||
if (ranker_model_->proto().model_case() == RankerModelProto::kNnClassifier) { | ||
model_ = ranker_model_->proto().nn_classifier(); | ||
return true; | ||
} | ||
|
||
DVLOG(0) << "Could not initialize inference module."; | ||
return false; | ||
} | ||
|
||
} // namespace assist_ranker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
// Copyright 2018 The Chromium Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style license that can be | ||
// found in the LICENSE file. | ||
|
||
#ifndef COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_ | ||
#define COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_ | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "base/compiler_specific.h" | ||
#include "components/assist_ranker/base_predictor.h" | ||
#include "components/assist_ranker/nn_classifier.h" | ||
#include "components/assist_ranker/proto/ranker_example.pb.h" | ||
|
||
namespace base { | ||
class FilePath; | ||
} | ||
|
||
namespace network { | ||
class SharedURLLoaderFactory; | ||
} | ||
|
||
namespace assist_ranker { | ||
|
||
// Predictor class for single-layer neural network models. | ||
class ClassifierPredictor : public BasePredictor { | ||
public: | ||
~ClassifierPredictor() override; | ||
|
||
// Returns an new predictor instance with the given |config| and initialize | ||
// its model loader. The |request_context getter| is passed to the | ||
// predictor's model_loader which holds it as scoped_refptr. | ||
static std::unique_ptr<ClassifierPredictor> Create( | ||
const PredictorConfig& config, | ||
const base::FilePath& model_path, | ||
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) | ||
WARN_UNUSED_RESULT; | ||
|
||
// Performs inferencing on the specified RankerExample. The example is first | ||
// preprocessed using the model config. Returns false if a prediction could | ||
// not be made (e.g. the model is not loaded yet). | ||
bool Predict(RankerExample example, | ||
std::vector<float>* prediction) WARN_UNUSED_RESULT; | ||
|
||
// Performs inferencing on the specified feature vector. Returns false if | ||
// a prediction could not be made. | ||
bool Predict(const std::vector<float>& features, | ||
std::vector<float>* prediction) WARN_UNUSED_RESULT; | ||
|
||
// Validates that the loaded RankerModel is a valid BinaryClassifier model. | ||
static RankerModelStatus ValidateModel(const RankerModel& model); | ||
|
||
protected: | ||
// Instantiates the inference module. | ||
bool Initialize() override; | ||
|
||
private: | ||
friend class ClassifierPredictorTest; | ||
ClassifierPredictor(const PredictorConfig& config); | ||
|
||
NNClassifierModel model_; | ||
|
||
DISALLOW_COPY_AND_ASSIGN(ClassifierPredictor); | ||
}; | ||
|
||
} // namespace assist_ranker | ||
#endif // COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_ |
Oops, something went wrong.