forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
proto_utils.h
383 lines (325 loc) · 13 KB
/
proto_utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
#ifndef CAFFE2_UTILS_PROTO_UTILS_H_
#define CAFFE2_UTILS_PROTO_UTILS_H_
#ifdef CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message_lite.h>
#else // CAFFE2_USE_LITE_PROTO
#include <google/protobuf/message.h>
#endif // !CAFFE2_USE_LITE_PROTO
#include <c10/util/Logging.h>
#include <c10/util/string_view.h>
#include <c10/util/irange.h>
#include "caffe2/utils/proto_wrap.h"
#include "caffe2/proto/caffe2_pb.h"
// Argument-map lookup helper.
// On non-Android builds CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP is
// defined, which makes ArgumentHelper::arg_map_ use the std::less<> comparator;
// the key is then handed to find() as-is. On Android builds the map keeps its
// default comparator, so the key is first converted to a std::string.
#ifndef C10_ANDROID
#define CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(key)
#else
#define CAFFE2_ARG_MAP_FIND(map, key) map.find(std::string(key))
#endif
namespace caffe2 {
using std::string;
using ::google::protobuf::MessageLite;
// A wrapper function to return the device name string for device type `d`,
// for use in blob serialization / deserialization. This should have a
// one-to-one correspondence with caffe2/proto/caffe2.proto: enum DeviceType.
//
// Note that we can't use DeviceType_Name, because that is only available in
// protobuf-full, and some platforms (like mobile) may want to use
// protobuf-lite instead.
TORCH_API std::string DeviceTypeName(const int32_t& d);
// Returns an integer device id for `option`.
// NOTE(review): which DeviceOption field(s) are consulted is decided by the
// implementation, which is not visible in this header -- confirm.
TORCH_API int DeviceId(const DeviceOption& option);
// Returns true if the two DeviceOptions are pointing to the same device.
TORCH_API bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
// Predicates on a device type value.
// NOTE(review): presumably `device_type` is a caffe2.proto DeviceType enum
// value; which values count as CPU / GPU is decided by the implementation --
// confirm.
TORCH_API bool IsCPUDeviceType(int device_type);
TORCH_API bool IsGPUDeviceType(int device_type);
// Common interfaces that read a file's contents into a string and write a
// string out to a file.
// NOTE(review): the bool result presumably reports success -- confirm against
// the implementation, which is not part of this header.
TORCH_API bool ReadStringFromFile(const char* filename, string* str);
TORCH_API bool WriteStringToFile(const string& str, const char* filename);
// Common interfaces that are supported by both lite and full protobuf.
TORCH_API bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
// std::string convenience overload: forwards to the const char* version.
// The file name is taken by const reference so that a call does not copy the
// string (the previous by-value `const string` parameter did).
inline bool ReadProtoFromBinaryFile(const string& filename, MessageLite* proto) {
  return ReadProtoFromBinaryFile(filename.c_str(), proto);
}
// Binary writer for `proto`; declared here and implemented outside this header.
TORCH_API void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
// std::string convenience overload: hands the C string of `filename` to the
// const char* version.
inline void WriteProtoToBinaryFile(
    const MessageLite& proto,
    const string& filename) {
  WriteProtoToBinaryFile(proto, filename.c_str());
}
#ifdef CAFFE2_USE_LITE_PROTO
namespace TextFormat {
// Text-format parsing stub for lite-protobuf builds. Both arguments are
// ignored; the call only logs a fatal error. It exists so that callers still
// compile when CAFFE2_USE_LITE_PROTO is defined.
inline bool ParseFromString(const string& /*spec*/, MessageLite* /*proto*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
  return false; // Gives the function a return value on every path.
}
} // namespace TextFormat
// Lite-protobuf declarations of the debug-string and large-string parsing
// helpers. Both are implemented outside this header.
// NOTE(review): the bool result of ParseProtoFromLargeString presumably
// reports whether parsing succeeded -- confirm against the implementation.
TORCH_API string ProtoDebugString(const MessageLite& proto);
TORCH_API bool ParseProtoFromLargeString(const string& str, MessageLite* proto);
// Text-format MessageLite wrappers: they exist only so that code compiles in
// lite-protobuf builds. Calling one produces a runtime error, since text
// support is not available when only MessageLite is in use.
//
// This reader stub ignores both arguments and logs a fatal error.
inline bool ReadProtoFromTextFile(const char* /*filename*/, MessageLite* /*proto*/) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
  return false; // Just to suppress compiler warning.
}
// std::string convenience overload of the lite-build text reader: forwards to
// the const char* overload. The file name is taken by const reference so that
// a call does not copy the string (the previous by-value `const string`
// parameter did).
inline bool ReadProtoFromTextFile(const string& filename, MessageLite* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}
// Text-format writer stub for lite-protobuf builds. Every argument is ignored;
// the call only logs a fatal error.
inline void WriteProtoToTextFile(
    const MessageLite& /*proto*/,
    const char* /*filename*/,
    bool /*throwIfError*/ = true) {
  LOG(FATAL) << "If you are running lite version, you should not be "
             << "calling any text-format protobuffers.";
}
// std::string convenience overload of the lite-build text writer: passes the
// C string of `filename`, together with `throwIfError`, to the const char*
// overload.
inline void WriteProtoToTextFile(
    const MessageLite& proto,
    const string& filename,
    bool throwIfError = true) {
  WriteProtoToTextFile(proto, filename.c_str(), throwIfError);
}
// Reads a proto from `filename`: the binary reader is tried first, and the
// text reader is called only when the binary reader returns false.
inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
  if (ReadProtoFromBinaryFile(filename, proto)) {
    return true;
  }
  return ReadProtoFromTextFile(filename, proto);
}
inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#else // CAFFE2_USE_LITE_PROTO
// Full-protobuf build: the helpers in this branch take
// ::google::protobuf::Message. The ones declared with TORCH_API are
// implemented outside this header.
using ::google::protobuf::Message;
namespace TextFormat {
// Parses the text `spec` into `proto`.
// NOTE(review): the bool result presumably reports whether parsing
// succeeded -- confirm against the implementation.
TORCH_API bool ParseFromString(const string& spec, Message* proto);
} // namespace TextFormat
// Debug-string and large-string parsing helpers for full protobuf messages.
TORCH_API string ProtoDebugString(const Message& proto);
TORCH_API bool ParseProtoFromLargeString(const string& str, Message* proto);
// Text-format reader for full protobuf messages; implemented outside this
// header.
TORCH_API bool ReadProtoFromTextFile(const char* filename, Message* proto);
// std::string convenience overload: forwards to the const char* version.
// The file name is taken by const reference so that a call does not copy the
// string (the previous by-value `const string` parameter did).
inline bool ReadProtoFromTextFile(const string& filename, Message* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}
// Text-format writer for full protobuf messages; implemented outside this
// header. `throwIfError` defaults to true.
TORCH_API void WriteProtoToTextFile(
    const Message& proto,
    const char* filename,
    bool throwIfError = true);
// std::string convenience overload: passes the C string of `filename`,
// together with `throwIfError`, to the const char* version.
inline void WriteProtoToTextFile(
    const Message& proto,
    const string& filename,
    bool throwIfError = true) {
  WriteProtoToTextFile(proto, filename.c_str(), throwIfError);
}
// Reads a proto from a file without the caller having to know whether it is
// binary or text: the binary reader is tried first, and the text reader is
// called only when the binary reader returns false.
inline bool ReadProtoFromFile(const char* filename, Message* proto) {
  if (ReadProtoFromBinaryFile(filename, proto)) {
    return true;
  }
  return ReadProtoFromTextFile(filename, proto);
}
inline bool ReadProtoFromFile(const string& filename, Message* proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#endif // CAFFE2_USE_LITE_PROTO
// Assembles an OperatorDef from its parts:
//  - `type` and `name` are always set on the result;
//  - each element of `inputs` / `outputs` is appended with add_input /
//    add_output;
//  - each element of `args` is copied into a newly added arg entry;
//  - `device_option` is copied only when it has a device type set;
//  - `engine` is set only when it is not empty.
template <
    class IterableInputs = std::initializer_list<string>,
    class IterableOutputs = std::initializer_list<string>,
    class IterableArgs = std::initializer_list<Argument>>
OperatorDef CreateOperatorDef(
    const string& type,
    const string& name,
    const IterableInputs& inputs,
    const IterableOutputs& outputs,
    const IterableArgs& args,
    const DeviceOption& device_option = DeviceOption(),
    const string& engine = "") {
  OperatorDef op;
  op.set_type(type);
  op.set_name(name);

  for (const string& input_name : inputs) {
    op.add_input(input_name);
  }
  for (const string& output_name : outputs) {
    op.add_output(output_name);
  }
  for (const Argument& argument : args) {
    auto* new_arg = op.add_arg();
    new_arg->CopyFrom(argument);
  }

  // Leave the device option unset on the result unless the caller supplied
  // one with a device type.
  if (device_option.has_device_type()) {
    op.mutable_device_option()->CopyFrom(device_option);
  }
  if (!engine.empty()) {
    op.set_engine(engine);
  }
  return op;
}
// Convenience overload of CreateOperatorDef for operators that take no
// arguments: forwards to the full version with an empty
// std::vector<Argument>.
template <
    class IterableInputs = std::initializer_list<string>,
    class IterableOutputs = std::initializer_list<string>>
inline OperatorDef CreateOperatorDef(
    const string& type,
    const string& name,
    const IterableInputs& inputs,
    const IterableOutputs& outputs,
    const DeviceOption& device_option = DeviceOption(),
    const string& engine = "") {
  const std::vector<Argument> no_args;
  return CreateOperatorDef(
      type, name, inputs, outputs, no_args, device_option, engine);
}
// Membership checks on an operator definition.
// NOTE(review): presumably these return true when `output` / `input` appears
// among the outputs / inputs of `op` -- confirm against the implementation,
// which is not part of this header.
TORCH_API bool HasOutput(const OperatorDef& op, const std::string& output);
TORCH_API bool HasInput(const OperatorDef& op, const std::string& input);
/**
 * @brief A helper class to index into arguments.
 *
 * This helper helps us to more easily index into a set of arguments
 * that are present in the operator. To save memory, the argument helper
 * does not copy the operator def, so one would need to make sure that the
 * lifetime of the OperatorDef object outlives that of the ArgumentHelper.
 *
 * NOTE(review): arg_map_ below holds Argument objects by value; whether the
 * lifetime requirement above still applies depends on the constructors, which
 * are defined outside this header -- confirm.
 *
 * The static overloads that take a `def` build a temporary ArgumentHelper
 * from it and forward to the member function of the same name.
 */
class C10_EXPORT ArgumentHelper {
 public:
  template <typename Def>
  static bool HasArgument(const Def& def, c10::string_view name) {
    return ArgumentHelper(def).HasArgument(name);
  }

  template <typename Def, typename T>
  static T GetSingleArgument(
      const Def& def,
      c10::string_view name,
      const T& default_value) {
    return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
  }

  template <typename Def, typename T>
  static bool HasSingleArgumentOfType(const Def& def, c10::string_view name) {
    return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
  }

  template <typename Def, typename T>
  static std::vector<T> GetRepeatedArgument(
      const Def& def,
      c10::string_view name,
      const std::vector<T>& default_value = std::vector<T>()) {
    return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
  }

  template <typename Def, typename MessageType>
  static MessageType GetMessageArgument(const Def& def, c10::string_view name) {
    return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
  }

  template <typename Def, typename MessageType>
  static std::vector<MessageType> GetRepeatedMessageArgument(
      const Def& def,
      c10::string_view name) {
    return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
  }

  /**
   * Removes the argument at position `index` from `def`.
   *
   * The element is swapped with the last argument and the last slot is then
   * dropped, so the former last argument ends up at `index`; the order of the
   * remaining arguments is therefore not preserved.
   *
   * @return false (leaving `def` untouched) when `index` is negative or not
   *         smaller than def.arg_size(); true after a removal.
   */
  template <typename Def>
  static bool RemoveArgument(Def& def, int index) {
    // Reject every out-of-range position. Negative values are checked too, so
    // that they never reach SwapElements.
    if (index < 0 || index >= def.arg_size()) {
      return false;
    }
    if (index < def.arg_size() - 1) {
      def.mutable_arg()->SwapElements(index, def.arg_size() - 1);
    }
    def.mutable_arg()->RemoveLast();
    return true;
  }

  // Constructors and the non-template-bodied members below are declared here
  // and defined outside this header.
  explicit ArgumentHelper(const OperatorDef& def);
  explicit ArgumentHelper(const NetDef& netdef);

  bool HasArgument(c10::string_view name) const;

  template <typename T>
  T GetSingleArgument(c10::string_view name, const T& default_value) const;

  template <typename T>
  bool HasSingleArgumentOfType(c10::string_view name) const;

  template <typename T>
  std::vector<T> GetRepeatedArgument(
      c10::string_view name,
      const std::vector<T>& default_value = std::vector<T>()) const;

  /**
   * Looks up the argument `name` and parses its `s` field into a MessageType.
   *
   * Enforces (CAFFE_ENFORCE) that the argument exists and that parsing
   * succeeds. When the argument has no `s` field, a default-constructed
   * MessageType is returned and a verbose log line is emitted.
   */
  template <typename MessageType>
  MessageType GetMessageArgument(c10::string_view name) const {
    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
    CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
    MessageType message;
    if (it->second.has_s()) {
      CAFFE_ENFORCE(
          message.ParseFromString(it->second.s()),
          "Failed to parse content from the string");
    } else {
      VLOG(1) << "Return empty message for parameter " << name;
    }
    return message;
  }

  /**
   * Looks up the argument `name` and parses each entry of its `strings` field
   * into a MessageType, returning the results in the same order.
   *
   * Enforces (CAFFE_ENFORCE) that the argument exists and that every entry
   * parses.
   */
  template <typename MessageType>
  std::vector<MessageType> GetRepeatedMessageArgument(c10::string_view name) const {
    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name);
    CAFFE_ENFORCE(it != arg_map_.end(), "Cannot find parameter named ", name);
    // Take the element count once and use it both to size the vector and as
    // the loop bound, so the loop compares int against int rather than int
    // against the unsigned messages.size().
    const int count = it->second.strings_size();
    std::vector<MessageType> messages(count);
    for (int i = 0; i < count; ++i) {
      CAFFE_ENFORCE(
          messages[i].ParseFromString(it->second.strings(i)),
          "Failed to parse content from the string");
    }
    return messages;
  }

 private:
  // Argument lookup table keyed by argument name. When
  // CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP is defined the map uses
  // the std::less<> comparator.
  std::map<string, Argument
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
  , std::less<>
#endif
  > arg_map_;
};
// **** Arguments Utils *****
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Throws if argument does not exist.
TORCH_API const Argument& GetArgument(const OperatorDef& def, c10::string_view name);
TORCH_API const Argument& GetArgument(const NetDef& def, c10::string_view name);
// Helper methods to get an argument from OperatorDef or NetDef given argument
// name. Returns nullptr if argument does not exist.
TORCH_API const Argument* GetArgumentPtr(const OperatorDef& def, c10::string_view name);
TORCH_API const Argument* GetArgumentPtr(const NetDef& def, c10::string_view name);
// Helper methods to query a boolean argument flag from OperatorDef or NetDef
// given argument name. If argument does not exist, return default value.
// Throws if argument exists but the type is not boolean.
TORCH_API bool GetFlagArgument(
const OperatorDef& def,
c10::string_view name,
bool default_value = false);
TORCH_API bool GetFlagArgument(
const NetDef& def,
c10::string_view name,
bool default_value = false);
// Helper methods to get a mutable pointer to an argument of an OperatorDef or
// NetDef given argument name.
// NOTE(review): presumably a missing argument is added to `def` when
// create_if_missing is true, and nullptr is returned when it is false --
// confirm against the implementation, which is not part of this header.
TORCH_API Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
OperatorDef* def);
TORCH_API Argument* GetMutableArgument(
const string& name,
const bool create_if_missing,
NetDef* def);
// Builds an Argument from `name` and `value`. Only declared in this header.
// NOTE(review): presumably defined / specialized per supported T in the
// implementation file -- confirm which value types are available.
template <typename T>
TORCH_API Argument MakeArgument(const string& name, const T& value);
template <typename T, typename Def>
inline void AddArgument(const string& name, const T& value, Def* def) {
GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
}
// **** End Arguments Utils *****
// Equality for DeviceOption is defined as IsSameDevice(dl, dr).
inline bool operator==(const DeviceOption& dl, const DeviceOption& dr) {
  const bool same_device = IsSameDevice(dl, dr);
  return same_device;
}
// Given a net, modify the external inputs/outputs if necessary so that
// the following conditions are met:
// - No duplicate external inputs
// - No duplicate external outputs
// - Going through list of ops in order, all op inputs must be outputs
// from other ops, or registered as external inputs.
// - All external outputs must be outputs of some operators.
// `net` is passed as a mutable pointer and is the object that gets modified.
TORCH_API void cleanupExternalInputsAndOutputs(NetDef* net);
} // namespace caffe2
namespace std {
// std::hash specialization for caffe2::DeviceOption. The option is serialized
// to a string (enforcing that serialization succeeds) and that string is
// hashed with std::hash<std::string>.
// NOTE(review): operator== for DeviceOption in this header delegates to
// IsSameDevice, while this hash covers the whole serialized message. If
// IsSameDevice ignores any field, two equal options could hash differently --
// confirm against the IsSameDevice implementation.
template <>
struct hash<caffe2::DeviceOption> {
  using argument_type = caffe2::DeviceOption;
  using result_type = std::size_t;

  result_type operator()(argument_type const& device_option) const {
    std::string serialized;
    CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
    const std::hash<std::string> string_hasher{};
    return string_hasher(serialized);
  }
};
} // namespace std
#endif // CAFFE2_UTILS_PROTO_UTILS_H_