Skip to content

Commit

Permalink
Get pretty printer ready for use as a serialization format (pytorch#1…
Browse files Browse the repository at this point in the history
…3616)

Summary:
Get pretty printer ready for use as a serialization format

This PR adds a bunch of functionality to the pretty printer (now called python_printer to reflect
the fact that it will be used to output valid python source). The idea is to get the printer
ready for use as serialization format.  This PR does not have tests beyond what the pretty
printer already had. PRs stacked on this one will do round-trip export/import to test this functionality more robustly.

Notes:
* PythonPrinter is an evolution of the original pretty printer. However, much of it has changed so it is best just to
  read it as a new implementation. Trying to correlate it to the original implementation is probably not much help.
* The printer tries to get reasonably close to how the original function was likely written, such as
  writing expressions rather than making intermediates when possible. We may decide to turn this off
  for the actual serialization, but it is useful for pretty printing.
* tensor field access was changed so that prim::device and family have schema
* fixed a bug in the compiler where setUniqueName gets called even when a value already has one.
  this sometimes assigned really poor names to graph inputs
* Graph::insert gains an optional range argument to make range-preserving inserts easier.
* prim:: ops that can have schema now have schema. This is because when we parse them back in,
  we will need the schema to correctly set their output types.
* there is code in the python printer to complain if you try to add a prim op and do not update the printer.
* BuiltinModule is generalized to take an operator namespace and a version number for work in future commits.
Pull Request resolved: pytorch#13616

Reviewed By: goldsborough

Differential Revision: D13008252

Pulled By: zdevito

fbshipit-source-id: 32b33bc6410d6ca1c6f02bd6e050f8d5eea32083
  • Loading branch information
zdevito authored and facebook-github-bot committed Nov 12, 2018
1 parent b7a7ab3 commit aef9e76
Show file tree
Hide file tree
Showing 35 changed files with 799 additions and 565 deletions.
7 changes: 4 additions & 3 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,16 @@ namespace c10 {
_(prim, IntToFloat) \
_(prim, FloatToInt) \
_(prim, StringToFloat) \
_(prim, TensorDevice) \
_(prim, TensorDType) \
_(prim, TensorShape) \
_(prim, device) \
_(prim, dtype) \
_(prim, shape) \
_(prim, AutogradAdd) \
_(prim, GradOf) \
_(prim, AnyDefined) \
_(prim, FusedConcat) \
_(prim, ConstantChunk) \
_(prim, NoneGenerator) \
_(prim, MMTreeReduce) \
_(aten, floordiv) \
_(prim, fork) \
_(prim, RaiseException) \
Expand Down
22 changes: 11 additions & 11 deletions test/expect/TestBatched.test_if_else.expect
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,33 @@ graph(%a.1_data : Dynamic
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%16 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%16)
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%data.5 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%21 : bool = prim::Constant[value=1]()
%22 : int = prim::Constant[value=1]()
%23 : Dynamic = aten::type_as(%8, %7)
%cond_mask.1 : Dynamic = aten::mul(%7, %23)
%25 : int = aten::dim(%cond_mask.1)
%data.2 : Dynamic = aten::mul(%7, %23)
%25 : int = aten::dim(%data.2)
%26 : bool = aten::eq(%25, %22)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
block0() {
%30 : int = aten::dim(%data.1)
%31 : int = aten::sub(%30, %22)
%data.3 : Dynamic = prim::Loop(%31, %21, %cond_mask.1)
%data.4 : Dynamic = prim::Loop(%31, %21, %data.2)
block0(%_ : int, %34 : Dynamic) {
%35 : int = aten::dim(%34)
%data.2 : Dynamic = aten::unsqueeze(%34, %35)
-> (%21, %data.2)
%data.3 : Dynamic = aten::unsqueeze(%34, %35)
-> (%21, %data.3)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
-> (%cond_data.1, %cond_mask.2, %data.3)
%cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask.1)
-> (%cond_data.1, %cond_mask.1, %data.4)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-> (%data.2, %data.2, %data.2)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
%res_data : Dynamic = aten::where(%cond_data, %data.1, %data.5)
%res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
%res_dims : Dynamic = aten::__or__(%dims.1, %dims)
return (%res_data, %res_mask, %res_dims);
Expand Down
22 changes: 11 additions & 11 deletions test/expect/TestBatched.test_if_else_with_scalar.expect
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,33 @@ graph(%a.1_data : Dynamic
%dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%17 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%17)
%data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%data.5 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha)
%mask : Dynamic = aten::mul(%a.1_mask, %b_mask)
%dims : Dynamic = aten::__or__(%a.1_dims, %b_dims)
%22 : bool = prim::Constant[value=1]()
%23 : int = prim::Constant[value=1]()
%24 : Dynamic = aten::type_as(%a.1_mask, %10)
%cond_mask.1 : Dynamic = aten::mul(%10, %24)
%26 : int = aten::dim(%cond_mask.1)
%data.2 : Dynamic = aten::mul(%10, %24)
%26 : int = aten::dim(%data.2)
%27 : bool = aten::eq(%26, %23)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
block0() {
%31 : int = aten::dim(%data.1)
%32 : int = aten::sub(%31, %23)
%data.3 : Dynamic = prim::Loop(%32, %22, %cond_mask.1)
%data.4 : Dynamic = prim::Loop(%32, %22, %data.2)
block0(%_ : int, %35 : Dynamic) {
%36 : int = aten::dim(%35)
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
-> (%22, %data.2)
%data.3 : Dynamic = aten::unsqueeze(%35, %36)
-> (%22, %data.3)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1)
-> (%cond_data.1, %cond_mask.2, %data.3)
%cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask.1)
-> (%cond_data.1, %cond_mask.1, %data.4)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-> (%data.2, %data.2, %data.2)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4)
%res_data : Dynamic = aten::where(%cond_data, %data.1, %data.5)
%res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask)
%res_dims : Dynamic = aten::__or__(%dims.1, %dims)
return (%res_data, %res_mask, %res_dims);
Expand Down
18 changes: 9 additions & 9 deletions test/expect/TestBatched.test_if_noelse.expect
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@ graph(%a.1_data : Dynamic
%16 : bool = prim::Constant[value=1]()
%17 : int = prim::Constant[value=1]()
%18 : Dynamic = aten::type_as(%8, %7)
%cond_mask.1 : Dynamic = aten::mul(%7, %18)
%20 : int = aten::dim(%cond_mask.1)
%data.2 : Dynamic = aten::mul(%7, %18)
%20 : int = aten::dim(%data.2)
%21 : bool = aten::eq(%20, %17)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
block0() {
%25 : int = aten::dim(%data.1)
%26 : int = aten::sub(%25, %17)
%data.3 : Dynamic = prim::Loop(%26, %16, %cond_mask.1)
%data.4 : Dynamic = prim::Loop(%26, %16, %data.2)
block0(%_ : int, %29 : Dynamic) {
%30 : int = aten::dim(%29)
%data.2 : Dynamic = aten::unsqueeze(%29, %30)
-> (%16, %data.2)
%data.3 : Dynamic = aten::unsqueeze(%29, %30)
-> (%16, %data.3)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
-> (%cond_data.1, %cond_mask.2, %data.3)
%cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask)
-> (%cond_data.1, %cond_mask.1, %data.4)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-> (%data.2, %data.2, %data.2)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data)
%res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask)
Expand Down
18 changes: 9 additions & 9 deletions test/expect/TestBatched.test_if_noelse_with_scalar.expect
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@ graph(%a.1_data : Dynamic
%17 : bool = prim::Constant[value=1]()
%18 : int = prim::Constant[value=1]()
%19 : Dynamic = aten::type_as(%a.1_mask, %10)
%cond_mask.1 : Dynamic = aten::mul(%10, %19)
%21 : int = aten::dim(%cond_mask.1)
%data.2 : Dynamic = aten::mul(%10, %19)
%21 : int = aten::dim(%data.2)
%22 : bool = aten::eq(%21, %18)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%22)
block0() {
%26 : int = aten::dim(%data.1)
%27 : int = aten::sub(%26, %18)
%data.3 : Dynamic = prim::Loop(%27, %17, %cond_mask.1)
%data.4 : Dynamic = prim::Loop(%27, %17, %data.2)
block0(%_ : int, %30 : Dynamic) {
%31 : int = aten::dim(%30)
%data.2 : Dynamic = aten::unsqueeze(%30, %31)
-> (%17, %data.2)
%data.3 : Dynamic = aten::unsqueeze(%30, %31)
-> (%17, %data.3)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
-> (%cond_data.1, %cond_mask.2, %data.3)
%cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask)
-> (%cond_data.1, %cond_mask.1, %data.4)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-> (%data.2, %data.2, %data.2)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data)
%res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask)
Expand Down
24 changes: 12 additions & 12 deletions test/expect/TestBatched.test_while.expect
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ graph(%a.1_data : Dynamic
%b_mask : Dynamic
%b_dims : Dynamic) {
%6 : int = prim::Constant[value=1]()
%7 : int = prim::Constant[value=2147483647]()
%7 : int = prim::Constant[value=9223372036854775807]()
%8 : Dynamic = aten::gt(%a.1_data, %b_data)
%9 : Dynamic = aten::mul(%a.1_mask, %b_mask)
%10 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
Expand All @@ -16,7 +16,7 @@ graph(%a.1_data : Dynamic
%15 : Dynamic = aten::gt(%14, %12)
%16 : bool = prim::TensorToBool(%15)
%17 : Dynamic, %18 : Dynamic, %19 : Dynamic, %a : Dynamic, %21 : Dynamic, %22 : Dynamic = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims)
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.2 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
%30 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::TensorToNum(%30)
%data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha)
Expand All @@ -28,26 +28,26 @@ graph(%a.1_data : Dynamic
%38 : bool = prim::TensorToBool(%35)
%39 : bool = prim::Constant[value=1]()
%40 : int = prim::Constant[value=1]()
%41 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %41)
%43 : int = aten::dim(%cond_mask.1)
%41 : Dynamic = aten::type_as(%cond_mask.2, %cond_data.2)
%data.2 : Dynamic = aten::mul(%cond_data.2, %41)
%43 : int = aten::dim(%data.2)
%44 : bool = aten::eq(%43, %40)
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%44)
block0() {
%48 : int = aten::dim(%data.1)
%49 : int = aten::sub(%48, %40)
%data.3 : Dynamic = prim::Loop(%49, %39, %cond_mask.1)
%data.4 : Dynamic = prim::Loop(%49, %39, %data.2)
block0(%_ : int, %52 : Dynamic) {
%53 : int = aten::dim(%52)
%data.2 : Dynamic = aten::unsqueeze(%52, %53)
-> (%39, %data.2)
%data.3 : Dynamic = aten::unsqueeze(%52, %53)
-> (%39, %data.3)
}
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
%cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask)
-> (%cond_data.1, %cond_mask.2, %data.3)
%cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1)
%cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask)
-> (%cond_data.1, %cond_mask.1, %data.4)
}
block1() {
-> (%cond_mask.1, %cond_mask.1, %cond_mask.1)
-> (%data.2, %data.2, %data.2)
}
%res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data)
%res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask)
Expand Down
20 changes: 10 additions & 10 deletions test/expect/TestJit.test_constant_prop_if_constant.expect
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
graph(%a : Dynamic
%b : Dynamic) {
%c2.1 : int = prim::Constant[value=1]()
%c0.1 : int = prim::Constant[value=1]()
%3 : bool = prim::TensorToBool(%a)
%c0.4 : int, %c1 : int = prim::If(%3)
%c0.5 : int, %c1 : int = prim::If(%3)
block0() {
%6 : bool = prim::TensorToBool(%b)
%c0.3 : int = prim::If(%6)
%c0.4 : int = prim::If(%6)
block0() {
%8 : int = prim::Constant[value=2]()
-> (%8)
}
block1() {
-> (%c2.1)
-> (%c0.1)
}
-> (%c0.3, %c2.1)
-> (%c0.4, %c0.1)
}
block1() {
%9 : int = prim::Constant[value=2]()
-> (%c2.1, %9)
-> (%c0.1, %9)
}
%c0.5 : int = aten::add(%c0.4, %c2.1)
%c0.6 : int = aten::add(%c0.5, %c0.1)
%11 : int = prim::Constant[value=5]()
%12 : Dynamic = aten::add(%a, %c0.5, %c2.1)
%13 : Dynamic = aten::add(%12, %c1, %c2.1)
%14 : Dynamic = aten::add(%13, %11, %c2.1)
%12 : Dynamic = aten::add(%a, %c0.6, %c0.1)
%13 : Dynamic = aten::add(%12, %c1, %c0.1)
%14 : Dynamic = aten::add(%13, %11, %c0.1)
return (%14);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ graph() {
%0 : bool = prim::Constant[value=0]()
%1 : bool = prim::Constant[value=1]()
%b.1 : int = prim::Constant[value=0]()
%3 : int = prim::Constant[value=2147483647]()
%3 : int = prim::Constant[value=9223372036854775807]()
%b.2 : int = prim::Constant[value=1]()
%b.4 : int = prim::Constant[value=2]()
%b.3 : int = prim::Loop(%3, %1, %b.1)
Expand Down
6 changes: 3 additions & 3 deletions test/expect/TestJit.test_export_expand_aten_fallback.expect
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ModelProto {
graph:
GraphProto {
name: "torch-jit-export"
inputs: [{name: "y.1", type:Tensor dims: 3 4 1}]
inputs: [{name: "x", type:Tensor dims: 3 4 1}]
outputs: [{name: "6", type:Tensor dims: 3 4 4}]
initializers: []
nodes: [
Expand All @@ -14,7 +14,7 @@ ModelProto {
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Loop", inputs: [3,2,y.1], outputs: [6], attributes: [{ name: 'body', type: graph, value:
Node {type: "Loop", inputs: [3,2,x], outputs: [6], attributes: [{ name: 'body', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "9", type:Tensor dims: }]
Expand All @@ -25,7 +25,7 @@ ModelProto {
Node {type: "Unsqueeze", inputs: [5], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]},
Node {type: "Concat", inputs: [10,11,12], outputs: [13], attributes: [{ name: 'axis', type: int, value: 0}]},
Node {type: "ATen", inputs: [y.1,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]}
Node {type: "ATen", inputs: [x,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]}
]
}

Expand Down
8 changes: 4 additions & 4 deletions test/expect/TestJit.test_function_default_values-bool.expect
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
graph(%result.1 : Dynamic
graph(%x : Dynamic
%a : Dynamic
%flag : Dynamic) {
%3 : int = prim::Constant[value=1]()
%4 : bool = prim::TensorToBool(%flag)
%result : Dynamic = prim::If(%4)
block0() {
-> (%result.1)
-> (%x)
}
block1() {
%result.2 : Dynamic = aten::add(%result.1, %a, %3)
-> (%result.2)
%result.1 : Dynamic = aten::add(%x, %a, %3)
-> (%result.1)
}
return (%result);
}
12 changes: 6 additions & 6 deletions test/expect/TestJit.test_pretty_printer-if_one.expect
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
def script(c2, c1):
t2 = aten::lt(c2, c1)
t3 = prim::TensorToBool(t2)
if t3:
c = c2
def script(
a: Tensor,
b: Tensor) -> Tensor:
if bool(aten.lt(a, b)):
c = a
else:
c = c1
c = b
return c
12 changes: 6 additions & 6 deletions test/expect/TestJit.test_pretty_printer-if_test.expect
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
def script(c2, c1):
t2 = aten::lt(c2, c1)
t3 = prim::TensorToBool(t2)
if t3:
c = c1
def script(
a: Tensor,
b: Tensor) -> Tensor:
if bool(aten.lt(a, b)):
c = b
else:
c = c2
c = a
return c
25 changes: 10 additions & 15 deletions test/expect/TestJit.test_pretty_printer-loop_use_test.expect
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
def script(y1):
x = aten::add(y1, 1, 1)
z1 = aten::add(x, 5, 1)
t9 = aten::lt(y1, 8)
t10 = prim::TensorToBool(t9)
y = y1
z = z1
t11 = t10
while t11:
y2 = aten::add_(y, 1, 1)
t17 = aten::lt(y2, 8)
t11 = prim::TensorToBool(t17)
y = y2
z = x
return (x, z)
def script(
y_1: Tensor) -> Tuple[Tensor, Tensor]:
x = aten.add(y_1, 1, 1)
z_1 = aten.add(x, 5, 1)
y, z = y_1, z_1
t = bool(aten.lt(y_1, 8))
while t:
y_2 = aten.add_(y, 1, 1)
t, y, z = bool(aten.lt(y_2, 8)), y_2, x
return x, z
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
def script(y):
def script(
y: Tensor) -> Tensor:
return ^python_fn()(y)
Loading

0 comments on commit aef9e76

Please sign in to comment.