Switch import/export to python printing (pytorch#14400)
Summary:
Stacked on pytorch#14378; only look at the last commit.

This changes the way methods are defined in TorchScript archives to use
PythonPrint rather than ONNX protobufs (see the sketch after the notes below).

It also updates torch.proto to directly document the tensor data
structure actually being serialized.

Notes:
* Because PythonPrint prints all the methods at once per module, this
  removes MethodDef in favor of a single torchscript_arena entry and a
  separate caffe2_nets entry. Note that NetDefs already have method names,
  so there is no need for a separate method-name entry.
* This switches the cpp/pickle arenas to RecordRef (a reference to a file in
  the container format), since the data in these arenas may be large and
  not suited to JSON output.
* Removes 'annotations' -- annotations should be re-added in the first
  commit that actually has a practical use for them. In their current state
  it is unlikely they represent the right information.
* Some expect files have changed because PythonPrint preserves more
  debug-name information for parameters.
* MethodEncoder (the ONNX output format) has been deleted. Some cleanup is
  still possible to combine EncoderBase and GraphEncoder now that there is
  only a single pathway using EncoderBase.
* This incorporates the changes from pytorch#14397 to define TensorDef.
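
A minimal sketch of the effect from the user's side (hedged: MyModule and the
file name are illustrative, and this uses the public torch.jit API of this
release rather than code from this PR):

import torch

class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x, y):
        # compiled to a TorchScript graph; now serialized as printed source
        return x * 2 + y

m = MyModule()
m.save("m.pt")               # archive stores methods as Python-printed source
m2 = torch.jit.load("m.pt")  # import parses and re-compiles that source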
Pull Request resolved: pytorch#14400

Reviewed By: suo

Differential Revision: D13231800

Pulled By: zdevito

fbshipit-source-id: af5c1152d0bd6bca8b06c4703f59b161bb19f571
zdevito authored and facebook-github-bot committed Nov 30, 2018
1 parent 2b7345b commit fd31eae
Showing 31 changed files with 365 additions and 993 deletions.
90 changes: 37 additions & 53 deletions caffe2/proto/torch.proto
@@ -4,64 +4,63 @@ import "caffe2/proto/caffe2.proto";
 
 package torch;
 
-message ParameterDef {
+message RecordRef {
+  optional string key = 1;
+  // size here refers to the uncompressed size, in bytes of the record
+  // this information also exists in the PyTorch container format data
+  // but is repeated here to make it possible to know size information without
+  // seeking to another record in the file.
+  optional int64 size = 2;
+}
+
+message TensorDef {
+  repeated int64 dims = 1;
+  optional int64 offset = 2;
+  repeated int64 strides = 3;
-  // whether we compute the gradient for the parameter
-  optional bool require_gradient = 1;
-  // whether this parameter is registered as buffer or not
-  optional bool is_buffer = 2;
+  optional bool requires_grad = 4;
+  optional caffe2.TensorProto.DataType data_type = 5;
 
-  // do not store tensor in parameter anymore, and retire field 3
-  // optional caffe2.TensorProto tensor = 3;
-  // the id in the tensor table, defined in TensorProto.name
-  optional string tensor_id = 5;
+  // objects other than tensors will be added here
+  optional RecordRef data = 6;
 
-  optional string name = 4;
+  // future: device options
 }
 
-message MethodDef {
-  // method name
-  // by default, we follow the naming convention below:
-  // 1) forward --> main method
-  // 2) init --> init method
-  optional string name = 1; // method name
-
-  // one of graph and torch_script must exist,
-  // if both exist, we reconstruct the graph from torch_script
-  optional caffe2.NetDef graph = 2;
-  optional string torch_script = 3;
-  // temporary place to store the methods of jit script modules
-  optional bytes onnx_proto = 101;
-
-  // inputs and outputs are inferred from graph or script
+message ParameterDef {
+  // whether this parameter is registered as buffer or not
+  optional bool is_buffer = 1;
+
+  // the offset into the tensor table where this parameter is stored
+  optional int64 tensor_id = 2;
+
+  optional string name = 3;
 }
 
 message ModuleDef {
   repeated ModuleDef submodules = 1;
 
-  // We suppose to store the modules in one of the following format:
-  // - methods (static graph or torch script)
-  // - pickle
-  // - cpp_arena
-  repeated MethodDef methods = 2;
+  optional RecordRef torchscript_arena = 2;
+
+  repeated caffe2.NetDef caffe2_nets = 3;
 
   // because the old pickle modules may not be supported by torch_script,
   // have to stored as pickle_arena at this moment.
-  optional bytes pickle_arena = 3;
+  optional RecordRef pickle_arena = 4;
   // should be exposed by the Class Archive, so user can save
   // module specific data which cannot be store in the graph or torch_script
-  optional bytes cpp_arena = 4;
+  optional RecordRef cpp_arena = 5;
 
   // the parameters of this module
-  repeated ParameterDef parameters = 5;
+  repeated ParameterDef parameters = 6;
 
+  // the names of inputs and outputs of the module are inferred
+  // from the main method.
+
-  optional string name = 6;
+  optional string name = 7;
 
   // whether apply the optimizations to this module, only applicable to
   // script modules
-  optional bool optimize = 7;
+  optional bool optimize = 8;
 }
 
 enum ProtoVersion {
@@ -84,24 +83,9 @@ message ModelDef {
 
   // put build version here
   optional string producer_version = 4;
 
-  optional string name = 5;
-
-  // metadata
-  // - exporter - string (either "CAFFE2" or "PYTORCH"),
-  //   to help the runtime understand who exports the model
-  // - debug_info - string
-  //   for MetaNetDef:
-  //   - project - string
-  //   - model_class - string
-  //   - internal_version - string
-  //   - predictor_type - string
-  //   - predictor_id - string
-  //   - execute_plan - string
-  //   - applicationSpecificInfo
-  //   - publish_time - string
-  repeated caffe2.Argument annotations = 6;
-
-  // the table contains all the tensor information
-  // the tensor id is defined as TensorProto.name
-  repeated caffe2.TensorProto tensors = 7;
+  repeated TensorDef tensors = 5;
 
   // future: add a way to provide additional meta-data
 }
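
To make the new schema concrete, here is a hedged sketch of how the messages
compose, written against the generated Python protobuf bindings (the import
path caffe2.proto.torch_pb2 and every field value below are assumptions made
for illustration, not part of this commit):

from caffe2.proto import caffe2_pb2, torch_pb2

model = torch_pb2.ModelDef()
module = model.main_module
module.name = "MyModule"                # hypothetical module name
# All of a module's methods live in one Python-printed source record.
module.torchscript_arena.key = "code"   # hypothetical record key
module.torchscript_arena.size = 512     # uncompressed size, in bytes

param = module.parameters.add()         # ParameterDef
param.name = "weight"
param.is_buffer = False
param.tensor_id = 0                     # offset into model.tensors

t = model.tensors.add()                 # TensorDef for that parameter
t.dims.extend([2, 3])
t.offset = 0
t.strides.extend([3, 1])
t.requires_grad = True
t.data_type = caffe2_pb2.TensorProto.FLOAT
t.data.key = "tensors/0"                # RecordRef into the container file
t.data.size = 24                        # 2 * 3 floats, 4 bytes each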
8 changes: 4 additions & 4 deletions test/expect/TestJit.test_broadcast_fusion_cuda.expect
@@ -1,7 +1,7 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*)
-      %2 : Float(*)) {
-  %3 : Float(*, *) = prim::FusionGroup_0(%2, %0, %1)
+graph(%x : Float(*, *)
+      %scale : Float(*)
+      %shift : Float(*)) {
+  %3 : Float(*, *) = prim::FusionGroup_0(%shift, %x, %scale)
   return (%3);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
6 changes: 3 additions & 3 deletions test/expect/TestJit.test_concat_fusion_cuda.expect
@@ -1,6 +1,6 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)) {
-  %2 : Float(*, *) = prim::FusionGroup_0(%0, %1)
+graph(%hx : Float(*, *)
+      %cx : Float(*, *)) {
+  %2 : Float(*, *) = prim::FusionGroup_0(%hx, %cx)
   return (%2);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
18 changes: 9 additions & 9 deletions test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
@@ -1,17 +1,17 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)) {
+graph(%x : Float(*, *)
+      %y : Float(*, *)
+      %z : Float(*, *)) {
   %3 : int = prim::Constant[value=1]()
-  %4 : Float(*, *) = prim::FusionGroup_0(%0, %1)
-  %5 : Float(*, *) = aten::add(%4, %2, %3)
+  %w : Float(*, *) = prim::FusionGroup_0(%x, %y)
+  %5 : Float(*, *) = aten::add(%w, %z, %3)
   return (%5);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Float(*, *)) {
   %2 : int = prim::Constant[value=1]()
-  %3 : Float(*, *) = aten::add(%0, %1, %2)
+  %x1 : Float(*, *) = aten::add(%0, %1, %2)
   %4 : int = prim::Constant[value=1]()
-  %5 : Float(*, *) = aten::sub(%0, %1, %4)
-  %6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5)
-  return (%6);
+  %y1 : Float(*, *) = aten::sub(%0, %1, %4)
+  %w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
+  return (%w);
 }
12 changes: 6 additions & 6 deletions test/expect/TestJit.test_constant_prop_loop_constant.expect
@@ -3,15 +3,15 @@ graph() {
   %1 : bool = prim::Constant[value=1]()
   %b.1 : int = prim::Constant[value=0]()
   %3 : int = prim::Constant[value=9223372036854775807]()
-  %b.2 : int = prim::Constant[value=1]()
-  %b.4 : int = prim::Constant[value=2]()
-  %b.3 : int = prim::Loop(%3, %1, %b.1)
+  %4 : int = prim::Constant[value=1]()
+  %5 : int = prim::Constant[value=2]()
+  %b.2 : int = prim::Loop(%3, %1, %b.1)
     block0(%7 : int, %8 : int) {
-      -> (%1, %b.2)
+      -> (%1, %4)
     }
-  %b : int = prim::Loop(%3, %0, %b.3)
+  %b : int = prim::Loop(%3, %0, %b.2)
     block0(%10 : int, %11 : int) {
-      -> (%0, %b.4)
+      -> (%0, %5)
     }
   return (%b);
 }
6 changes: 3 additions & 3 deletions test/expect/TestJit.test_fuse_last_device_cuda.expect
@@ -1,6 +1,6 @@
-graph(%0 : Float(*)
-      %1 : Float(*)) {
-  %2 : Float(*) = prim::FusionGroup_0(%0, %1)
+graph(%x : Float(*)
+      %y : Float(*)) {
+  %2 : Float(*) = prim::FusionGroup_0(%x, %y)
   return (%2);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
6 changes: 3 additions & 3 deletions test/expect/TestJit.test_fusion_distribute_cuda.expect
@@ -1,6 +1,6 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)) {
-  %2 : Dynamic[] = prim::ListConstruct(%0, %1)
+graph(%x : Float(*, *)
+      %y : Float(*, *)) {
+  %2 : Dynamic[] = prim::ListConstruct(%x, %y)
   %3 : Dynamic[] = aten::broadcast_tensors(%2)
   %4 : Dynamic, %5 : Dynamic = prim::ListUnpack(%3)
   %6 : Float(*, *) = prim::FusionGroup_0(%5, %4)
3 changes: 2 additions & 1 deletion test/expect/TestJit.test_import_method.expect
@@ -1,4 +1,5 @@
 def graph(self,
     x: Tensor,
     y: Tensor) -> Tensor:
-  return aten.add(aten.mul(x, 2), y, alpha=1)
+  _0 = torch.add(torch.mul(x, 2), y, alpha=1)
+  return _0
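
The printed form now calls into the torch.* namespace rather than aten.*, so
the serialized text reads like ordinary Python. A hedged reconstruction of
source that would round-trip to the method above (not necessarily the test's
exact code):

import torch

@torch.jit.script
def fn(x, y):
    # x * 2 + y prints as torch.add(torch.mul(x, 2), y, alpha=1)
    return x * 2 + y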
46 changes: 23 additions & 23 deletions test/expect/TestJit.test_lstm_fusion_concat_cuda.expect
@@ -1,18 +1,18 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
-      %3 : Float(*, *)
-      %4 : Float(*, *)
-      %5 : Float(*)
-      %6 : Float(*)) {
-  %7 : Float(*, *) = aten::t(%3)
-  %8 : Float(*, *) = aten::mm(%0, %7)
-  %9 : Float(*, *) = aten::t(%4)
-  %10 : Float(*, *) = aten::mm(%1, %9)
-  %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+      %input : Float(*, *)
+      %cx : Float(*, *)
+      %weight_1 : Float(*, *)
+      %weight : Float(*, *)
+      %bias_1 : Float(*)
+      %bias : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight_1)
+  %8 : Float(*, *) = aten::mm(%input_1, %7)
+  %9 : Float(*, *) = aten::t(%weight)
+  %10 : Float(*, *) = aten::mm(%input, %9)
+  %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
   %12 : Dynamic[] = aten::broadcast_tensors(%11)
   %13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
-  %17 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
+  %17 : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
   return (%17);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
@@ -48,16 +48,16 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %42 : Float(*, *) = aten::add(%34, %26, %41)
   %43 : int = prim::Constant[value=1]()
   %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %45 : Float(*, *) = aten::sigmoid(%38)
-  %46 : Float(*, *) = aten::sigmoid(%40)
-  %47 : Float(*, *) = aten::tanh(%42)
-  %48 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%46, %0)
-  %50 : Float(*, *) = aten::mul(%45, %47)
+  %ingate : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate : Float(*, *) = aten::sigmoid(%40)
+  %cellgate : Float(*, *) = aten::tanh(%42)
+  %outgate : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
   %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%52)
-  %54 : Float(*, *) = aten::mul(%48, %53)
-  %55 : Float(*, *) = prim::FusedConcat[dim=0](%54, %52)
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate, %53)
+  %55 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
   return (%55);
 }
48 changes: 24 additions & 24 deletions test/expect/TestJit.test_lstm_fusion_cuda.expect
@@ -1,19 +1,19 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
-      %3 : Float(*, *)
-      %4 : Float(*, *)
-      %5 : Float(*)
-      %6 : Float(*)) {
-  %7 : Float(*, *) = aten::t(%3)
-  %8 : Float(*, *) = aten::mm(%0, %7)
-  %9 : Float(*, *) = aten::t(%4)
-  %10 : Float(*, *) = aten::mm(%1, %9)
-  %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+      %input : Float(*, *)
+      %cx : Float(*, *)
+      %weight_1 : Float(*, *)
+      %weight : Float(*, *)
+      %bias_1 : Float(*)
+      %bias : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight_1)
+  %8 : Float(*, *) = aten::mm(%input_1, %7)
+  %9 : Float(*, *) = aten::t(%weight)
+  %10 : Float(*, *) = aten::mm(%input, %9)
+  %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
   %12 : Dynamic[] = aten::broadcast_tensors(%11)
   %13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
-  %17 : Float(*, *), %18 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
-  return (%17, %18);
+  %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
+  return (%17, %cy);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Dynamic
@@ -48,15 +48,15 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %42 : Float(*, *) = aten::add(%34, %26, %41)
   %43 : int = prim::Constant[value=1]()
   %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %45 : Float(*, *) = aten::sigmoid(%38)
-  %46 : Float(*, *) = aten::sigmoid(%40)
-  %47 : Float(*, *) = aten::tanh(%42)
-  %48 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%46, %0)
-  %50 : Float(*, *) = aten::mul(%45, %47)
+  %ingate : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate : Float(*, *) = aten::sigmoid(%40)
+  %cellgate : Float(*, *) = aten::tanh(%42)
+  %outgate : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
   %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%52)
-  %54 : Float(*, *) = aten::mul(%48, %53)
-  return (%54, %52);
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %54 : Float(*, *) = aten::mul(%outgate, %53)
+  return (%54, %cy);
 }
@@ -1,4 +1,4 @@
 def graph(self,
     y: Tensor) -> int:
   x = annotate(List[int], [])
-  return aten.select(x, 0)
+  return torch.select(x, 0)
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_pretty_printer-if_one.expect
@@ -1,7 +1,7 @@
 def graph(self,
     a: Tensor,
     b: Tensor) -> Tensor:
-  if bool(aten.lt(a, b)):
+  if bool(torch.lt(a, b)):
     c = a
   else:
     c = b
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_pretty_printer-if_test.expect
@@ -1,7 +1,7 @@
 def graph(self,
     a: Tensor,
     b: Tensor) -> Tensor:
-  if bool(aten.lt(a, b)):
+  if bool(torch.lt(a, b)):
     c = b
   else:
     c = a
10 changes: 5 additions & 5 deletions test/expect/TestJit.test_pretty_printer-loop_use_test.expect
@@ -1,10 +1,10 @@
 def graph(self,
     y_1: Tensor) -> Tuple[Tensor, Tensor]:
-  x = aten.add(y_1, 1, 1)
-  z_1 = aten.add(x, 5, 1)
+  x = torch.add(y_1, 1, 1)
+  z_1 = torch.add(x, 5, 1)
   y, z = y_1, z_1
-  _0 = bool(aten.lt(y_1, 8))
+  _0 = bool(torch.lt(y_1, 8))
   while _0:
-    y_2 = aten.add_(y, 1, 1)
-    _0, y, z = bool(aten.lt(y_2, 8)), y_2, x
+    y_2 = torch.add_(y, 1, 1)
+    _0, y, z = bool(torch.lt(y_2, 8)), y_2, x
   return x, z
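
As above, a hedged reconstruction of TorchScript whose printed form matches
this expect file (names chosen to mirror the printed output, not necessarily
the test's exact source):

import torch

@torch.jit.script
def loop_use_test(y):
    x = y + 1
    z = x + 5
    while bool(y < 8):
        y += 1
        z = x
    return x, z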
