Skip to content

Commit

Permalink
Add node_name to DeviceOption
Browse files Browse the repository at this point in the history
Summary: Allow for generalizing net transforms.

Reviewed By: Yangqing

Differential Revision: D5812140

fbshipit-source-id: e3f30acad362ae1f0614ee218d331b525710b88e
  • Loading branch information
aazzolini authored and facebook-github-bot committed Sep 13, 2017
1 parent 37af656 commit 68f3584
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 7 deletions.
9 changes: 9 additions & 0 deletions caffe2/core/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ NetBase::NetBase(
def->external_output().begin(),
def->external_output().end()),
name_(def->name()) {
// Check that node_name is empty for all ops
for (const OperatorDef& op : def->op()) {
if (op.has_device_option()) {
CAFFE_ENFORCE(
!op.device_option().has_node_name(),
"node_name must be empty for all operators at execution time.");
}
}

// Go through the operators and make sure that blobs are correctly made.
std::set<string> known_blobs(
external_input_.begin(), external_input_.end());
Expand Down
3 changes: 3 additions & 0 deletions caffe2/proto/caffe2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ message DeviceOption {
optional int32 cuda_gpu_id = 2;
// [general] The random seed to start the device random number generator with.
optional uint32 random_seed = 3;
// [general] What node this op should execute on.
// Used for net transformation purposes. Must be empty at execution time.
optional string node_name = 4;
}

// Operator Definition.
Expand Down
9 changes: 5 additions & 4 deletions caffe2/utils/proto_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#include <cerrno>
#include <fstream>

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>

#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#include <google/protobuf/text_format.h>
#endif // !CAFFE2_USE_LITE_PROTO

#include "caffe2/core/logging.h"
Expand Down Expand Up @@ -42,7 +42,8 @@ std::string DeviceTypeName(const int32_t& d) {
bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
return (
lhs.device_type() == rhs.device_type() &&
lhs.cuda_gpu_id() == rhs.cuda_gpu_id());
lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
lhs.node_name() == rhs.node_name());
}

bool ReadStringFromFile(const char* filename, string* str) {
Expand Down
7 changes: 4 additions & 3 deletions caffe2/utils/proto_utils.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
#define CAFFE2_UTILS_PROTO_UTILS_H_

#include "google/protobuf/message_lite.h"
#ifndef CAFFE2_USE_LITE_PROTO
#include "google/protobuf/message.h"
#ifdef CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message_lite.h>
#else // CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message.h>
#endif // !CAFFE2_USE_LITE_PROTO

#include "caffe2/core/logging.h"
Expand Down
17 changes: 17 additions & 0 deletions caffe2/utils/proto_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@

namespace caffe2 {

TEST(ProtoUtilsTest, IsSameDevice) {
DeviceOption a;
DeviceOption b;
EXPECT_TRUE(IsSameDevice(a, b));
a.set_node_name("my_node");
EXPECT_FALSE(IsSameDevice(a, b));
b.set_node_name("my_node");
EXPECT_TRUE(IsSameDevice(a, b));
b.set_cuda_gpu_id(2);
EXPECT_FALSE(IsSameDevice(a, b));
a.set_cuda_gpu_id(2);
EXPECT_TRUE(IsSameDevice(a, b));
a.set_device_type(DeviceType::CUDA);
b.set_device_type(DeviceType::CPU);
EXPECT_FALSE(IsSameDevice(a, b));
}

TEST(ProtoUtilsTest, SimpleReadWrite) {
string content("The quick brown fox jumps over the lazy dog.");
string name = std::tmpnam(nullptr);
Expand Down

0 comments on commit 68f3584

Please sign in to comment.