diff --git a/binaries/speed_benchmark.cc b/binaries/speed_benchmark.cc index f421928c7bcc9..627492f0b6102 100644 --- a/binaries/speed_benchmark.cc +++ b/binaries/speed_benchmark.cc @@ -48,9 +48,7 @@ CAFFE2_DEFINE_string( "separated numbers. If multiple input needed, use " "semicolon to separate the dimension of different " "tensors."); -CAFFE2_DEFINE_string( - input_type, - "", "Input type (uint8_t/float)"); +CAFFE2_DEFINE_string(input_type, "", "Input type (uint8_t/float)"); CAFFE2_DEFINE_string( output, "", @@ -103,9 +101,13 @@ int main(int argc, char** argv) { workspace->CreateBlob(input_names[i])->Deserialize(blob_proto); } } else if (caffe2::FLAGS_input_dims.size() || caffe2::FLAGS_input_type.size()) { - CAFFE_ENFORCE_NE(0, caffe2::FLAGS_input_dims.size(), + CAFFE_ENFORCE_GE( + caffe2::FLAGS_input_dims.size(), + 0, "Input dims must be specified when input tensors are used."); - CAFFE_ENFORCE_NE(0, caffe2::FLAGS_input_type.size(), + CAFFE_ENFORCE_GE( + caffe2::FLAGS_input_type.size(), + 0, "Input type must be specified when input tensors are used."); vector input_dims_list = @@ -150,6 +152,9 @@ int main(int argc, char** argv) { // Run main network. CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def)); + if (!net_def.has_name()) { + net_def.set_name("benchmark"); + } // force changing engine and algo if (caffe2::FLAGS_force_engine) { LOG(INFO) << "force engine be: " << caffe2::FLAGS_engine; @@ -167,6 +172,7 @@ int main(int argc, char** argv) { } caffe2::NetBase* net = workspace->CreateNet(net_def); CHECK_NOTNULL(net); + CAFFE_ENFORCE(net->Run()); net->TEST_Benchmark( caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual);