Skip to content

Commit

Permalink
[caffe2] Add default values to speed_benchmark args (pytorch#6210)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 authored and Yangqing committed Apr 3, 2018
1 parent fd2e7cb commit 2e156f3
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions binaries/speed_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ CAFFE2_DEFINE_string(
"separated numbers. If multiple input needed, use "
"semicolon to separate the dimension of different "
"tensors.");
CAFFE2_DEFINE_string(
input_type,
"", "Input type (uint8_t/float)");
CAFFE2_DEFINE_string(input_type, "", "Input type (uint8_t/float)");
CAFFE2_DEFINE_string(
output,
"",
Expand Down Expand Up @@ -103,9 +101,13 @@ int main(int argc, char** argv) {
workspace->CreateBlob(input_names[i])->Deserialize(blob_proto);
}
} else if (caffe2::FLAGS_input_dims.size() || caffe2::FLAGS_input_type.size()) {
CAFFE_ENFORCE_NE(0, caffe2::FLAGS_input_dims.size(),
CAFFE_ENFORCE_GE(
caffe2::FLAGS_input_dims.size(),
0,
"Input dims must be specified when input tensors are used.");
CAFFE_ENFORCE_NE(0, caffe2::FLAGS_input_type.size(),
CAFFE_ENFORCE_GE(
caffe2::FLAGS_input_type.size(),
0,
"Input type must be specified when input tensors are used.");

vector<string> input_dims_list =
Expand Down Expand Up @@ -150,6 +152,9 @@ int main(int argc, char** argv) {

// Run main network.
CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def));
if (!net_def.has_name()) {
net_def.set_name("benchmark");
}
// force changing engine and algo
if (caffe2::FLAGS_force_engine) {
LOG(INFO) << "force engine be: " << caffe2::FLAGS_engine;
Expand All @@ -167,6 +172,7 @@ int main(int argc, char** argv) {
}
caffe2::NetBase* net = workspace->CreateNet(net_def);
CHECK_NOTNULL(net);
CAFFE_ENFORCE(net->Run());
net->TEST_Benchmark(
caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual);

Expand Down

0 comments on commit 2e156f3

Please sign in to comment.