forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Get pretty printer ready for use as a serialization format (pytorch#1…
…3616) Summary: Get pretty printer ready for use as a serialization format This PR adds a bunch of functionality to the pretty printer (now called python_printer to reflect the fact that it will be used to output valid python source). The idea is to get the printer ready for use as serialization format. This PR does not have tests beyond what the pretty printer already had. PRs stacked on this one will do round-trip export/import to test this functionality more robustly. Notes: * PythonPrinter is an evolution of the original pretty printer. However, much of it has changed so it is best just to read it as a new implementation. Trying to correlate it to the original implementation is probably not much help. * The printer tries to get reasonably close to how the original function was likely written, such as writing expressions rather than making intermediates when possible. We may decide to turn this off for the actual serialization, but it is useful for pretty printing. * tensor field access was changed so that prim::device and family have schema * fixed a bug in the compiler where setUniqueName gets called even when a value already has one. this sometimes assigned really poor names to graph inputs * Graph::insert gains an optional range argument to make range-preserving inserts easier. * prim:: ops that can have schema now have schema. This is because when we parse them back in, we will need the schema to correctly set their output types. * there is code in the python printer to complain if you try to add a prim op and do not update the printer. * BuiltinModule is generalized to take an operator namespace and a version number for work in future commits. Pull Request resolved: pytorch#13616 Reviewed By: goldsborough Differential Revision: D13008252 Pulled By: zdevito fbshipit-source-id: 32b33bc6410d6ca1c6f02bd6e050f8d5eea32083
- Loading branch information
1 parent
b7a7ab3
commit aef9e76
Showing
35 changed files
with
799 additions
and
565 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 10 additions & 10 deletions
20
test/expect/TestJit.test_constant_prop_if_constant.expect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,28 @@ | ||
graph(%a : Dynamic | ||
%b : Dynamic) { | ||
%c2.1 : int = prim::Constant[value=1]() | ||
%c0.1 : int = prim::Constant[value=1]() | ||
%3 : bool = prim::TensorToBool(%a) | ||
%c0.4 : int, %c1 : int = prim::If(%3) | ||
%c0.5 : int, %c1 : int = prim::If(%3) | ||
block0() { | ||
%6 : bool = prim::TensorToBool(%b) | ||
%c0.3 : int = prim::If(%6) | ||
%c0.4 : int = prim::If(%6) | ||
block0() { | ||
%8 : int = prim::Constant[value=2]() | ||
-> (%8) | ||
} | ||
block1() { | ||
-> (%c2.1) | ||
-> (%c0.1) | ||
} | ||
-> (%c0.3, %c2.1) | ||
-> (%c0.4, %c0.1) | ||
} | ||
block1() { | ||
%9 : int = prim::Constant[value=2]() | ||
-> (%c2.1, %9) | ||
-> (%c0.1, %9) | ||
} | ||
%c0.5 : int = aten::add(%c0.4, %c2.1) | ||
%c0.6 : int = aten::add(%c0.5, %c0.1) | ||
%11 : int = prim::Constant[value=5]() | ||
%12 : Dynamic = aten::add(%a, %c0.5, %c2.1) | ||
%13 : Dynamic = aten::add(%12, %c1, %c2.1) | ||
%14 : Dynamic = aten::add(%13, %11, %c2.1) | ||
%12 : Dynamic = aten::add(%a, %c0.6, %c0.1) | ||
%13 : Dynamic = aten::add(%12, %c1, %c0.1) | ||
%14 : Dynamic = aten::add(%13, %11, %c0.1) | ||
return (%14); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
graph(%result.1 : Dynamic | ||
graph(%x : Dynamic | ||
%a : Dynamic | ||
%flag : Dynamic) { | ||
%3 : int = prim::Constant[value=1]() | ||
%4 : bool = prim::TensorToBool(%flag) | ||
%result : Dynamic = prim::If(%4) | ||
block0() { | ||
-> (%result.1) | ||
-> (%x) | ||
} | ||
block1() { | ||
%result.2 : Dynamic = aten::add(%result.1, %a, %3) | ||
-> (%result.2) | ||
%result.1 : Dynamic = aten::add(%x, %a, %3) | ||
-> (%result.1) | ||
} | ||
return (%result); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
def script(c2, c1): | ||
t2 = aten::lt(c2, c1) | ||
t3 = prim::TensorToBool(t2) | ||
if t3: | ||
c = c2 | ||
def script( | ||
a: Tensor, | ||
b: Tensor) -> Tensor: | ||
if bool(aten.lt(a, b)): | ||
c = a | ||
else: | ||
c = c1 | ||
c = b | ||
return c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
def script(c2, c1): | ||
t2 = aten::lt(c2, c1) | ||
t3 = prim::TensorToBool(t2) | ||
if t3: | ||
c = c1 | ||
def script( | ||
a: Tensor, | ||
b: Tensor) -> Tensor: | ||
if bool(aten.lt(a, b)): | ||
c = b | ||
else: | ||
c = c2 | ||
c = a | ||
return c |
25 changes: 10 additions & 15 deletions
25
test/expect/TestJit.test_pretty_printer-loop_use_test.expect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,10 @@ | ||
def script(y1): | ||
x = aten::add(y1, 1, 1) | ||
z1 = aten::add(x, 5, 1) | ||
t9 = aten::lt(y1, 8) | ||
t10 = prim::TensorToBool(t9) | ||
y = y1 | ||
z = z1 | ||
t11 = t10 | ||
while t11: | ||
y2 = aten::add_(y, 1, 1) | ||
t17 = aten::lt(y2, 8) | ||
t11 = prim::TensorToBool(t17) | ||
y = y2 | ||
z = x | ||
return (x, z) | ||
def script( | ||
y_1: Tensor) -> Tuple[Tensor, Tensor]: | ||
x = aten.add(y_1, 1, 1) | ||
z_1 = aten.add(x, 5, 1) | ||
y, z = y_1, z_1 | ||
t = bool(aten.lt(y_1, 8)) | ||
while t: | ||
y_2 = aten.add_(y, 1, 1) | ||
t, y, z = bool(aten.lt(y_2, 8)), y_2, x | ||
return x, z |
3 changes: 2 additions & 1 deletion
3
test/expect/TestJit.test_pretty_printer-python_op_name_test.expect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
def script(y): | ||
def script( | ||
y: Tensor) -> Tensor: | ||
return ^python_fn()(y) |
Oops, something went wrong.