diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 385c31e90ac66..04572f784b37a 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -57,15 +57,16 @@ namespace c10 { _(prim, IntToFloat) \ _(prim, FloatToInt) \ _(prim, StringToFloat) \ - _(prim, TensorDevice) \ - _(prim, TensorDType) \ - _(prim, TensorShape) \ + _(prim, device) \ + _(prim, dtype) \ + _(prim, shape) \ _(prim, AutogradAdd) \ _(prim, GradOf) \ _(prim, AnyDefined) \ _(prim, FusedConcat) \ _(prim, ConstantChunk) \ _(prim, NoneGenerator) \ + _(prim, MMTreeReduce) \ _(aten, floordiv) \ _(prim, fork) \ _(prim, RaiseException) \ diff --git a/test/expect/TestBatched.test_if_else.expect b/test/expect/TestBatched.test_if_else.expect index 2856146efc721..ea748cd947328 100644 --- a/test/expect/TestBatched.test_if_else.expect +++ b/test/expect/TestBatched.test_if_else.expect @@ -16,33 +16,33 @@ graph(%a.1_data : Dynamic %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims) %16 : Long() = prim::NumToTensor(%6) %alpha : float = prim::TensorToNum(%16) - %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha) + %data.5 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha) %mask : Dynamic = aten::mul(%a.1_mask, %b_mask) %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims) %21 : bool = prim::Constant[value=1]() %22 : int = prim::Constant[value=1]() %23 : Dynamic = aten::type_as(%8, %7) - %cond_mask.1 : Dynamic = aten::mul(%7, %23) - %25 : int = aten::dim(%cond_mask.1) + %data.2 : Dynamic = aten::mul(%7, %23) + %25 : int = aten::dim(%data.2) %26 : bool = aten::eq(%25, %22) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26) block0() { %30 : int = aten::dim(%data.1) %31 : int = aten::sub(%30, %22) - %data.3 : Dynamic = prim::Loop(%31, %21, %cond_mask.1) + %data.4 : Dynamic = prim::Loop(%31, %21, %data.2) block0(%_ : int, %34 : Dynamic) { %35 : int = aten::dim(%34) - %data.2 : Dynamic = aten::unsqueeze(%34, %35) - -> (%21, %data.2) + %data.3 : Dynamic = aten::unsqueeze(%34, %35) + -> (%21, %data.3) } - %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) - %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1) - -> (%cond_data.1, %cond_mask.2, %data.3) + %cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1) + %cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask.1) + -> (%cond_data.1, %cond_mask.1, %data.4) } block1() { - -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) + -> (%data.2, %data.2, %data.2) } - %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4) + %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.5) %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask) %res_dims : Dynamic = aten::__or__(%dims.1, %dims) return (%res_data, %res_mask, %res_dims); diff --git a/test/expect/TestBatched.test_if_else_with_scalar.expect b/test/expect/TestBatched.test_if_else_with_scalar.expect index 901fc44bc9028..ed95c3a6a192c 100644 --- a/test/expect/TestBatched.test_if_else_with_scalar.expect +++ b/test/expect/TestBatched.test_if_else_with_scalar.expect @@ -17,33 +17,33 @@ graph(%a.1_data : Dynamic %dims.1 : Dynamic = aten::__or__(%a.1_dims, %b_dims) %17 : Long() = prim::NumToTensor(%6) %alpha : float = prim::TensorToNum(%17) - %data.4 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha) + %data.5 : Dynamic = aten::sub(%a.1_data, %b_data, %alpha) %mask : Dynamic = aten::mul(%a.1_mask, %b_mask) %dims : Dynamic = aten::__or__(%a.1_dims, %b_dims) %22 : bool = prim::Constant[value=1]() %23 : int = 
prim::Constant[value=1]() %24 : Dynamic = aten::type_as(%a.1_mask, %10) - %cond_mask.1 : Dynamic = aten::mul(%10, %24) - %26 : int = aten::dim(%cond_mask.1) + %data.2 : Dynamic = aten::mul(%10, %24) + %26 : int = aten::dim(%data.2) %27 : bool = aten::eq(%26, %23) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27) block0() { %31 : int = aten::dim(%data.1) %32 : int = aten::sub(%31, %23) - %data.3 : Dynamic = prim::Loop(%32, %22, %cond_mask.1) + %data.4 : Dynamic = prim::Loop(%32, %22, %data.2) block0(%_ : int, %35 : Dynamic) { %36 : int = aten::dim(%35) - %data.2 : Dynamic = aten::unsqueeze(%35, %36) - -> (%22, %data.2) + %data.3 : Dynamic = aten::unsqueeze(%35, %36) + -> (%22, %data.3) } - %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) - %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask.1) - -> (%cond_data.1, %cond_mask.2, %data.3) + %cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1) + %cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask.1) + -> (%cond_data.1, %cond_mask.1, %data.4) } block1() { - -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) + -> (%data.2, %data.2, %data.2) } - %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.4) + %res_data : Dynamic = aten::where(%cond_data, %data.1, %data.5) %res_mask : Dynamic = aten::where(%cond_mask, %mask.1, %mask) %res_dims : Dynamic = aten::__or__(%dims.1, %dims) return (%res_data, %res_mask, %res_dims); diff --git a/test/expect/TestBatched.test_if_noelse.expect b/test/expect/TestBatched.test_if_noelse.expect index 70070998a2605..526b68596d41f 100644 --- a/test/expect/TestBatched.test_if_noelse.expect +++ b/test/expect/TestBatched.test_if_noelse.expect @@ -17,25 +17,25 @@ graph(%a.1_data : Dynamic %16 : bool = prim::Constant[value=1]() %17 : int = prim::Constant[value=1]() %18 : Dynamic = aten::type_as(%8, %7) - %cond_mask.1 : Dynamic = aten::mul(%7, %18) - %20 : int = aten::dim(%cond_mask.1) + %data.2 : Dynamic = aten::mul(%7, %18) + %20 : int = aten::dim(%data.2) %21 : bool = aten::eq(%20, %17) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21) block0() { %25 : int = aten::dim(%data.1) %26 : int = aten::sub(%25, %17) - %data.3 : Dynamic = prim::Loop(%26, %16, %cond_mask.1) + %data.4 : Dynamic = prim::Loop(%26, %16, %data.2) block0(%_ : int, %29 : Dynamic) { %30 : int = aten::dim(%29) - %data.2 : Dynamic = aten::unsqueeze(%29, %30) - -> (%16, %data.2) + %data.3 : Dynamic = aten::unsqueeze(%29, %30) + -> (%16, %data.3) } - %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) - %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask) - -> (%cond_data.1, %cond_mask.2, %data.3) + %cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1) + %cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask) + -> (%cond_data.1, %cond_mask.1, %data.4) } block1() { - -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) + -> (%data.2, %data.2, %data.2) } %res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data) %res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask) diff --git a/test/expect/TestBatched.test_if_noelse_with_scalar.expect b/test/expect/TestBatched.test_if_noelse_with_scalar.expect index a1e6841f02b51..6951dc14322af 100644 --- a/test/expect/TestBatched.test_if_noelse_with_scalar.expect +++ b/test/expect/TestBatched.test_if_noelse_with_scalar.expect @@ -18,25 +18,25 @@ graph(%a.1_data : Dynamic %17 : bool = prim::Constant[value=1]() %18 : int = prim::Constant[value=1]() %19 : Dynamic = aten::type_as(%a.1_mask, %10) - %cond_mask.1 : Dynamic = 
aten::mul(%10, %19) - %21 : int = aten::dim(%cond_mask.1) + %data.2 : Dynamic = aten::mul(%10, %19) + %21 : int = aten::dim(%data.2) %22 : bool = aten::eq(%21, %18) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%22) block0() { %26 : int = aten::dim(%data.1) %27 : int = aten::sub(%26, %18) - %data.3 : Dynamic = prim::Loop(%27, %17, %cond_mask.1) + %data.4 : Dynamic = prim::Loop(%27, %17, %data.2) block0(%_ : int, %30 : Dynamic) { %31 : int = aten::dim(%30) - %data.2 : Dynamic = aten::unsqueeze(%30, %31) - -> (%17, %data.2) + %data.3 : Dynamic = aten::unsqueeze(%30, %31) + -> (%17, %data.3) } - %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) - %cond_mask.2 : Dynamic = aten::expand_as(%data.3, %mask) - -> (%cond_data.1, %cond_mask.2, %data.3) + %cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1) + %cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask) + -> (%cond_data.1, %cond_mask.1, %data.4) } block1() { - -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) + -> (%data.2, %data.2, %data.2) } %res_data : Dynamic = aten::where(%cond_data, %data.1, %a.1_data) %res_mask : Dynamic = aten::where(%cond_mask, %mask, %a.1_mask) diff --git a/test/expect/TestBatched.test_while.expect b/test/expect/TestBatched.test_while.expect index b55f9dd6a83ee..dd2a592a0990f 100644 --- a/test/expect/TestBatched.test_while.expect +++ b/test/expect/TestBatched.test_while.expect @@ -5,7 +5,7 @@ graph(%a.1_data : Dynamic %b_mask : Dynamic %b_dims : Dynamic) { %6 : int = prim::Constant[value=1]() - %7 : int = prim::Constant[value=2147483647]() + %7 : int = prim::Constant[value=9223372036854775807]() %8 : Dynamic = aten::gt(%a.1_data, %b_data) %9 : Dynamic = aten::mul(%a.1_mask, %b_mask) %10 : Dynamic = aten::__or__(%a.1_dims, %b_dims) @@ -16,7 +16,7 @@ graph(%a.1_data : Dynamic %15 : Dynamic = aten::gt(%14, %12) %16 : bool = prim::TensorToBool(%15) %17 : Dynamic, %18 : Dynamic, %19 : Dynamic, %a : Dynamic, %21 : Dynamic, %22 : Dynamic = prim::Loop(%7, %16, %8, %9, %10, %a.1_data, %a.1_mask, %a.1_dims) - block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) { + block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.2 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) { %30 : Long() = prim::NumToTensor(%6) %alpha : float = prim::TensorToNum(%30) %data.1 : Dynamic = aten::sub(%6_data, %b_data, %alpha) @@ -28,26 +28,26 @@ graph(%a.1_data : Dynamic %38 : bool = prim::TensorToBool(%35) %39 : bool = prim::Constant[value=1]() %40 : int = prim::Constant[value=1]() - %41 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2) - %cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %41) - %43 : int = aten::dim(%cond_mask.1) + %41 : Dynamic = aten::type_as(%cond_mask.2, %cond_data.2) + %data.2 : Dynamic = aten::mul(%cond_data.2, %41) + %43 : int = aten::dim(%data.2) %44 : bool = aten::eq(%43, %40) %cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%44) block0() { %48 : int = aten::dim(%data.1) %49 : int = aten::sub(%48, %40) - %data.3 : Dynamic = prim::Loop(%49, %39, %cond_mask.1) + %data.4 : Dynamic = prim::Loop(%49, %39, %data.2) block0(%_ : int, %52 : Dynamic) { %53 : int = aten::dim(%52) - %data.2 : Dynamic = aten::unsqueeze(%52, %53) - -> (%39, %data.2) + %data.3 : Dynamic = aten::unsqueeze(%52, %53) + -> (%39, %data.3) } - %cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1) - %cond_mask.2 : Dynamic = aten::expand_as(%data.3, 
%mask) - -> (%cond_data.1, %cond_mask.2, %data.3) + %cond_data.1 : Dynamic = aten::expand_as(%data.4, %data.1) + %cond_mask.1 : Dynamic = aten::expand_as(%data.4, %mask) + -> (%cond_data.1, %cond_mask.1, %data.4) } block1() { - -> (%cond_mask.1, %cond_mask.1, %cond_mask.1) + -> (%data.2, %data.2, %data.2) } %res_data : Dynamic = aten::where(%cond_data, %data.1, %6_data) %res_mask : Dynamic = aten::where(%cond_mask, %mask, %6_mask) diff --git a/test/expect/TestJit.test_constant_prop_if_constant.expect b/test/expect/TestJit.test_constant_prop_if_constant.expect index d1e11f4ab70cb..ed9113f240173 100644 --- a/test/expect/TestJit.test_constant_prop_if_constant.expect +++ b/test/expect/TestJit.test_constant_prop_if_constant.expect @@ -1,28 +1,28 @@ graph(%a : Dynamic %b : Dynamic) { - %c2.1 : int = prim::Constant[value=1]() + %c0.1 : int = prim::Constant[value=1]() %3 : bool = prim::TensorToBool(%a) - %c0.4 : int, %c1 : int = prim::If(%3) + %c0.5 : int, %c1 : int = prim::If(%3) block0() { %6 : bool = prim::TensorToBool(%b) - %c0.3 : int = prim::If(%6) + %c0.4 : int = prim::If(%6) block0() { %8 : int = prim::Constant[value=2]() -> (%8) } block1() { - -> (%c2.1) + -> (%c0.1) } - -> (%c0.3, %c2.1) + -> (%c0.4, %c0.1) } block1() { %9 : int = prim::Constant[value=2]() - -> (%c2.1, %9) + -> (%c0.1, %9) } - %c0.5 : int = aten::add(%c0.4, %c2.1) + %c0.6 : int = aten::add(%c0.5, %c0.1) %11 : int = prim::Constant[value=5]() - %12 : Dynamic = aten::add(%a, %c0.5, %c2.1) - %13 : Dynamic = aten::add(%12, %c1, %c2.1) - %14 : Dynamic = aten::add(%13, %11, %c2.1) + %12 : Dynamic = aten::add(%a, %c0.6, %c0.1) + %13 : Dynamic = aten::add(%12, %c1, %c0.1) + %14 : Dynamic = aten::add(%13, %11, %c0.1) return (%14); } diff --git a/test/expect/TestJit.test_constant_prop_loop_constant.expect b/test/expect/TestJit.test_constant_prop_loop_constant.expect index d2d1e53801612..63a42c652c70d 100644 --- a/test/expect/TestJit.test_constant_prop_loop_constant.expect +++ b/test/expect/TestJit.test_constant_prop_loop_constant.expect @@ -2,7 +2,7 @@ graph() { %0 : bool = prim::Constant[value=0]() %1 : bool = prim::Constant[value=1]() %b.1 : int = prim::Constant[value=0]() - %3 : int = prim::Constant[value=2147483647]() + %3 : int = prim::Constant[value=9223372036854775807]() %b.2 : int = prim::Constant[value=1]() %b.4 : int = prim::Constant[value=2]() %b.3 : int = prim::Loop(%3, %1, %b.1) diff --git a/test/expect/TestJit.test_export_expand_aten_fallback.expect b/test/expect/TestJit.test_export_expand_aten_fallback.expect index 4047f9c16b549..611d62727ddbb 100644 --- a/test/expect/TestJit.test_export_expand_aten_fallback.expect +++ b/test/expect/TestJit.test_export_expand_aten_fallback.expect @@ -5,7 +5,7 @@ ModelProto { graph: GraphProto { name: "torch-jit-export" - inputs: [{name: "y.1", type:Tensor dims: 3 4 1}] + inputs: [{name: "x", type:Tensor dims: 3 4 1}] outputs: [{name: "6", type:Tensor dims: 3 4 4}] initializers: [] nodes: [ @@ -14,7 +14,7 @@ ModelProto { Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Constant", inputs: [], outputs: [4], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, - Node {type: "Loop", inputs: [3,2,y.1], outputs: [6], attributes: [{ name: 'body', type: graph, value: + Node {type: "Loop", inputs: [3,2,x], outputs: [6], attributes: [{ name: 'body', type: 
graph, value: GraphProto { name: "torch-jit-export1" inputs: [{name: "i", type:Tensor dims: },{name: "cond", type:Tensor dims: },{name: "9", type:Tensor dims: }] @@ -25,7 +25,7 @@ ModelProto { Node {type: "Unsqueeze", inputs: [5], outputs: [11], attributes: [{ name: 'axes', type: ints, values: [0]}]}, Node {type: "Unsqueeze", inputs: [i], outputs: [12], attributes: [{ name: 'axes', type: ints, values: [0]}]}, Node {type: "Concat", inputs: [10,11,12], outputs: [13], attributes: [{ name: 'axis', type: int, value: 0}]}, - Node {type: "ATen", inputs: [y.1,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]} + Node {type: "ATen", inputs: [x,13,1], outputs: [15], attributes: [{ name: 'operator', type: string, value: 'expand'}]} ] } diff --git a/test/expect/TestJit.test_function_default_values-bool.expect b/test/expect/TestJit.test_function_default_values-bool.expect index 4e24f593875e0..fd8bbf553ea28 100644 --- a/test/expect/TestJit.test_function_default_values-bool.expect +++ b/test/expect/TestJit.test_function_default_values-bool.expect @@ -1,15 +1,15 @@ -graph(%result.1 : Dynamic +graph(%x : Dynamic %a : Dynamic %flag : Dynamic) { %3 : int = prim::Constant[value=1]() %4 : bool = prim::TensorToBool(%flag) %result : Dynamic = prim::If(%4) block0() { - -> (%result.1) + -> (%x) } block1() { - %result.2 : Dynamic = aten::add(%result.1, %a, %3) - -> (%result.2) + %result.1 : Dynamic = aten::add(%x, %a, %3) + -> (%result.1) } return (%result); } diff --git a/test/expect/TestJit.test_pretty_printer-if_one.expect b/test/expect/TestJit.test_pretty_printer-if_one.expect index cbede6c812757..26198231deeb0 100644 --- a/test/expect/TestJit.test_pretty_printer-if_one.expect +++ b/test/expect/TestJit.test_pretty_printer-if_one.expect @@ -1,8 +1,8 @@ -def script(c2, c1): - t2 = aten::lt(c2, c1) - t3 = prim::TensorToBool(t2) - if t3: - c = c2 +def script( + a: Tensor, + b: Tensor) -> Tensor: + if bool(aten.lt(a, b)): + c = a else: - c = c1 + c = b return c diff --git a/test/expect/TestJit.test_pretty_printer-if_test.expect b/test/expect/TestJit.test_pretty_printer-if_test.expect index c7e2249078845..468f6a3396415 100644 --- a/test/expect/TestJit.test_pretty_printer-if_test.expect +++ b/test/expect/TestJit.test_pretty_printer-if_test.expect @@ -1,8 +1,8 @@ -def script(c2, c1): - t2 = aten::lt(c2, c1) - t3 = prim::TensorToBool(t2) - if t3: - c = c1 +def script( + a: Tensor, + b: Tensor) -> Tensor: + if bool(aten.lt(a, b)): + c = b else: - c = c2 + c = a return c diff --git a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect index 0dd3a473ac1ed..6d29c0c4931ee 100644 --- a/test/expect/TestJit.test_pretty_printer-loop_use_test.expect +++ b/test/expect/TestJit.test_pretty_printer-loop_use_test.expect @@ -1,15 +1,10 @@ -def script(y1): - x = aten::add(y1, 1, 1) - z1 = aten::add(x, 5, 1) - t9 = aten::lt(y1, 8) - t10 = prim::TensorToBool(t9) - y = y1 - z = z1 - t11 = t10 - while t11: - y2 = aten::add_(y, 1, 1) - t17 = aten::lt(y2, 8) - t11 = prim::TensorToBool(t17) - y = y2 - z = x - return (x, z) +def script( + y_1: Tensor) -> Tuple[Tensor, Tensor]: + x = aten.add(y_1, 1, 1) + z_1 = aten.add(x, 5, 1) + y, z = y_1, z_1 + t = bool(aten.lt(y_1, 8)) + while t: + y_2 = aten.add_(y, 1, 1) + t, y, z = bool(aten.lt(y_2, 8)), y_2, x + return x, z diff --git a/test/expect/TestJit.test_pretty_printer-python_op_name_test.expect b/test/expect/TestJit.test_pretty_printer-python_op_name_test.expect index 
efeab5617fdc9..fd5a0b8930af5 100644 --- a/test/expect/TestJit.test_pretty_printer-python_op_name_test.expect +++ b/test/expect/TestJit.test_pretty_printer-python_op_name_test.expect @@ -1,2 +1,3 @@ -def script(y): +def script( + y: Tensor) -> Tensor: return ^python_fn()(y) diff --git a/test/expect/TestJit.test_pretty_printer-while_if_test.expect b/test/expect/TestJit.test_pretty_printer-while_if_test.expect index 4295fae4a588d..fbd85b83e83b2 100644 --- a/test/expect/TestJit.test_pretty_printer-while_if_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_if_test.expect @@ -1,23 +1,14 @@ -def script(a1, b1): - t5 = aten::lt(a1, 10) - t6 = prim::TensorToBool(t5) - a = a1 - b = b1 - c = 0 - t7 = t6 - while t7: - a2 = aten::add(a, 1, 1) - b2 = aten::add(b, 1, 1) - t15 = aten::gt(a2, b2) - t16 = prim::TensorToBool(t15) - if t16: - c4 = 2 +def script( + a_1: Tensor, + b_1: Tensor) -> Tensor: + a, b, c = a_1, b_1, 0 + t = bool(aten.lt(a_1, 10)) + while t: + a_2 = aten.add(a, 1, 1) + b_2 = aten.add(b, 1, 1) + if bool(aten.gt(a_2, b_2)): + c_4 = 2 else: - c4 = 3 - t21 = aten::lt(a2, 10) - t7 = prim::TensorToBool(t21) - a = a2 - b = b2 - c = c4 - t27 = aten::add(a, 1, 1) - return aten::add(t27, c, 1) + c_4 = 3 + t, a, b, c = bool(aten.lt(a_2, 10)), a_2, b_2, c_4 + return aten.add(aten.add(a, 1, 1), c, 1) diff --git a/test/expect/TestJit.test_pretty_printer-while_test.expect b/test/expect/TestJit.test_pretty_printer-while_test.expect index 242479616f960..393137c461c8e 100644 --- a/test/expect/TestJit.test_pretty_printer-while_test.expect +++ b/test/expect/TestJit.test_pretty_printer-while_test.expect @@ -1,13 +1,10 @@ -def script(a1, i1): - t4 = aten::lt(i1, 3) - t5 = prim::TensorToBool(t4) - a = a1 - i = i1 - t6 = t5 - while t6: - i2 = aten::add_(i, 1, 1) - t13 = aten::lt(i2, 3) - t6 = prim::TensorToBool(t13) - a = aten::mul_(a, a) - i = i2 +def script( + a_1: Tensor, + i_1: Tensor) -> Tensor: + a, i = a_1, i_1 + t = bool(aten.lt(i_1, 3)) + while t: + a_2 = aten.mul_(a, a) + i_2 = aten.add_(i, 1, 1) + t, a, i = bool(aten.lt(i_2, 3)), a_2, i_2 return a diff --git a/test/expect/TestJit.test_recursive_cse.expect b/test/expect/TestJit.test_recursive_cse.expect index dd56a5119362b..ac5ca2420fde0 100644 --- a/test/expect/TestJit.test_recursive_cse.expect +++ b/test/expect/TestJit.test_recursive_cse.expect @@ -1,15 +1,15 @@ -graph(%z.1 : Dynamic +graph(%x : Dynamic %y : Dynamic) { %2 : int = prim::Constant[value=1]() - %3 : Dynamic = aten::add(%z.1, %y, %2) - %4 : Dynamic = aten::gt(%3, %z.1) + %3 : Dynamic = aten::add(%x, %y, %2) + %4 : Dynamic = aten::gt(%3, %x) %5 : bool = prim::TensorToBool(%4) %z : Dynamic = prim::If(%5) block0() { -> (%3) } block1() { - -> (%z.1) + -> (%x) } return (%z); } diff --git a/test/expect/TestScript.test_if_supertype.expect b/test/expect/TestScript.test_if_supertype.expect index e57f49a27cf67..e8b9bf40fa493 100644 --- a/test/expect/TestScript.test_if_supertype.expect +++ b/test/expect/TestScript.test_if_supertype.expect @@ -1,13 +1,13 @@ -graph(%y.1 : Float(*, *) - %z.2 : Long(*, *) +graph(%x.1 : Float(*, *) + %y.1 : Long(*, *) %z.1 : Float(*, *)) { %3 : bool = prim::Constant[value=1]() %x : Float(*, *), %y : Dynamic, %z : Dynamic = prim::If(%3) block0() { - -> (%y.1, %z.2, %z.1) + -> (%x.1, %y.1, %z.1) } block1() { - -> (%y.1, %y.1, %z.2) + -> (%x.1, %x.1, %y.1) } return (%x, %y, %z); } diff --git a/test/test_jit.py b/test/test_jit.py index fafacd5d3ec24..9bcc30e0ff717 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -234,6 +234,7 @@ def 
getExportImportCopy(self, m): buffer = io.BytesIO() torch.jit.save(imported, buffer) buffer.seek(0) + return torch.jit.load(buffer) def assertGraphContains(self, graph, kind): diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 83e5787b84d8c..dd7bcc9edd5d6 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -183,7 +183,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/passes/shape_analysis.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp - ${TORCH_SRC_DIR}/csrc/jit/passes/pretty_print.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index af3a16393fe9d..04a64a64d5dd4 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -229,7 +229,7 @@ struct Attributes { return s; } - void printValue(std::ostream & out, Symbol & name) const { + void printValue(std::ostream & out, const Symbol & name) const { switch(kindOf(name)) { case AttributeKind::f: out << f(name); diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 5e5be199903ce..14b042f850854 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -6,7 +6,7 @@ #include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/assertions.h" #include "torch/csrc/jit/script/compiler.h" -#include "torch/csrc/jit/passes/pretty_print.h" +#include "torch/csrc/jit/passes/python_print.h" #include #include @@ -188,12 +188,12 @@ std::ostream& operator<<(std::ostream & out, const Graph & g) { } std::ostream& Graph::prettyPrint(std::ostream & out) { - PrettyPrint(out, *this); + PythonPrint(out, *this); return out; } void Graph::dumpPretty() { - PrettyPrint(std::cout, *this); + PythonPrint(std::cout, *this); } static void checkSameDevice(const Node* node) { @@ -1160,8 +1160,19 @@ inline const SourceRange& fakeRange() { return range; } -Value* Graph::insert(Symbol opname, at::ArrayRef args, at::ArrayRef kwargs) { - return script::emitBuiltinCall(fakeRange(), *this, opname, c10::nullopt, args, kwargs, /*required=*/true); +Value* Graph::insert( + Symbol opname, + at::ArrayRef args, + at::ArrayRef kwargs, + c10::optional range) { + return script::emitBuiltinCall( + range.value_or(fakeRange()), + *this, + opname, + c10::nullopt, + args, + kwargs, + /*required=*/true); } Node* Graph::create(NodeKind kind, size_t num_outputs) { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 183a084ec52be..940b331fbdc67 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -868,7 +868,11 @@ friend struct Block; // argument matching rules, and checks that the op matches a known schema // if this node successfully completes, it guarentees the node is a correctly-formed invocation // of opname - Value* insert(Symbol opname, at::ArrayRef args, at::ArrayRef kwargs = {}); + Value* insert( + Symbol opname, + at::ArrayRef args, + at::ArrayRef kwargs = {}, + c10::optional range = {}); Node * appendNode(Node * n) { return block_->appendNode(n); diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 5186fc0d8da1c..4e00d700a216c 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -3,7 +3,7 @@ #include "torch/csrc/jit/script/lexer.h" #include "torch/csrc/jit/script/tree.h" #include "torch/csrc/jit/operator.h" - +#include "torch/csrc/jit/passes/python_print.h" #include 
"torch/csrc/jit/script/error_report.h" namespace torch { namespace jit { @@ -419,6 +419,16 @@ OperatorRegistry& getRegistry() { } // anonymous namespace void registerOperator(Operator&& op) { + if(op.schema().is_varret()) { + Symbol s = Symbol::fromQualString(op.schema().name()); + if (!printerHasSpecialCaseFor(s)) { + std::cout << c10::str( + "missing special case in python printer for non-schematized operator ", + op.schema().name(), + ". File a bug to add a case for this operator.\n"); + } + } + getRegistry().registerOperator(std::move(op)); } diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 2bb0898fdf98d..508ab4e37509f 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -38,7 +38,7 @@ struct TORCH_API Operator { // arguments. This is used for things like prim::While or prim::If that can // take a number of different valid input types and lengths. Operator(Symbol name, OperationCreator op_creator) - : Operator(FunctionSchema(name, {}, {}, true), std::move(op_creator)) {} + : Operator(FunctionSchema(name, {}, {}, /*is_vararg*/true, /*is_varret*/true), std::move(op_creator)) {} Operator(FunctionSchema schema, Operation op) : schema_(std::make_shared(std::move(schema))), diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index d329491b88c34..0e4a0511175a8 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -90,7 +90,7 @@ bool shape_is_fast(const at::Tensor& lhs, const at::Tensor& rhs) { RegisterOperators mm_tree_reduction_reg({ Operator( - Symbol::prim("MMTreeReduce"), + prim::MMTreeReduce, [](const Node* node) { size_t num_inputs = node->inputs().size(); return [num_inputs](Stack& stack) { diff --git a/torch/csrc/jit/passes/pretty_print.cpp b/torch/csrc/jit/passes/pretty_print.cpp deleted file mode 100644 index 5140c5ba99a80..0000000000000 --- a/torch/csrc/jit/passes/pretty_print.cpp +++ /dev/null @@ -1,366 +0,0 @@ -#include "torch/csrc/jit/passes/pretty_print.h" -#include "torch/csrc/jit/attributes.h" -#include "torch/csrc/jit/generic_if.h" -#include "torch/csrc/jit/ir.h" - -namespace torch { -namespace jit { - -static std::ostream& indent(std::ostream& out, size_t level) { - for (size_t i = 0; i < level; ++i) { - out << " "; - } - return out; -} - -class PrettyPrintPass { - const Graph& graph_; - - // When printing a name if there is a conflict with an existing name in the - // graph, record the value -> new generated name mapping - std::unordered_map aliases_; - - // The Graph already tracks unique_names_, this is just for additional ones - // generated during printing - std::unordered_map generated_names_; - - // Cache of value names - std::unordered_map value_names_; - - // Nodes that were skipped to be printed later - std::unordered_set skipped_nodes_; - - template - void zipWith( - at::ArrayRef list_a, - at::ArrayRef list_b, - std::function action) const { - auto it_a = list_a.begin(); - auto it_b = list_b.begin(); - - if (list_a.size() != list_b.size()) { - AT_ERROR("Pretty printer expected 2 lists of same size"); - } - - for (; it_a != list_a.end(); ++it_a, ++it_b) { - action(*it_a, *it_b); - } - } - - std::ostream& printValueList( - std::ostream& out, - at::ArrayRef list) { - out << "("; - auto delimiter = ""; - for (const auto* value : list) { - out << delimiter; - printValue(out, value); - delimiter = ", "; - } - out << ")"; - return out; - } - - void printAssignment( - std::ostream& out, - const Value* lhs, - const Value* rhs, - const size_t level) { - 
indent(out, level); - printValue(out, lhs); - out << " = "; - printValue(out, rhs); - out << "\n"; - } - - std::ostream& printIf( - std::ostream& out, - const Node* node, - const size_t level) { - indent(out, level); - out << "if "; - const auto if_block = node->blocks()[0]; - const auto else_block = node->blocks()[1]; - printValue(out, node->inputs()[0]); - out << ":" - << "\n"; - - // Print node contents - printBlock(out, if_block, level + 1); - - // Print if block output - zipWith( - node->outputs(), - if_block->outputs(), - [&](const Value* node_output, const Value* return_input) { - printAssignment(out, node_output, return_input, level + 1); - }); - - indent(out, level); - out << "else:\n"; - printBlock(out, else_block, level + 1); - zipWith( - node->outputs(), - else_block->outputs(), - [&](const Value* node_output, const Value* return_input) { - printAssignment(out, node_output, return_input, level + 1); - }); - - return out; - } - - std::ostream& printLoop( - std::ostream& out, - const Node* node, - const size_t level) { - // Prints assignments between the loop node and body block around the - // loop body itself in the following manner: - // - // (silently) alias block input names to node output names - // assign each node input to the corresponding node output - // assign condition to loop condition value - // while ...: - // print loop body nodes - // assign each block output to the corresponding block input - - const auto body_block = node->blocks()[0]; - // Add aliases for loop-carried dependencies - zipWith( - body_block->inputs().slice(1), // Start at 1 to ignore trip count - node->outputs(), - [&](const Value* block_input, const Value* node_output) { - aliases_[block_input] = node_output; - }); - - // Print initial assignments of loop node outputs = loop node inputs - zipWith( - node->outputs(), - node->inputs().slice(2), - [&](const Value* node_output, const Value* node_input) { - printAssignment(out, node_output, node_input, level); - }); - - // Print condition initial assignment - printAssignment(out, body_block->inputs()[0], node->inputs()[1], level); - - // Loop header - indent(out, level); - out << "while "; - printValue(out, body_block->inputs()[0]); - out << ":\n"; - - // Loop body - printBlock(out, body_block, level + 1); - - // Update block outputs to block inputs for next loop iteration - zipWith( - body_block->inputs(), - body_block->outputs(), - [&](const Value* block_input, const Value* block_output) { - printAssignment(out, block_input, block_output, level + 1); - }); - return out; - } - - // Returns false if the node has no outputs, or if outputs are used anywhere - // except in a single prim::Return node - bool nodeOnlyOutputReturns(const Node* node) { - if (node->outputs().size() == 0) { - return false; - } - for (const auto* output : node->outputs()) { - if (output->uses().size() != 1) { - return false; - } - if (output->uses()[0].user->kind() != prim::Return) { - return false; - } - } - return true; - } - - std::ostream& printNode( - std::ostream& out, - const Node* node, - const size_t level) { - switch (node->kind()) { - case prim::Return: - break; - case prim::Constant: - break; - case prim::Loop: - printLoop(out, node, level); - break; - case prim::If: - printIf(out, node, level); - break; - default: - if (nodeOnlyOutputReturns(node)) { - // This node is assigned to a temp which is then used by prim::Return, - // so just print this node there - skipped_nodes_.insert(node); - return out; - } - - indent(out, level); - // Print outputs - if 
(node->outputs().size() > 0) { - auto delim = ""; - for (const auto* output_value : node->outputs()) { - out << delim; - printValue(out, output_value); - delim = ", "; - } - out << " = "; - } - - printRHS(out, node); - - out << "\n"; - } - - return out; - } - - // Prints the RHS value of a Node, e.g. `aten::add(x, y)` - std::ostream& printRHS(std::ostream& out, const Node* node) { - IR_IFM_CONST(node, PythonOp) - out << "^" << value->name(); - value->writeScalars(out); - IR_ELSE() - out << node->kind().toQualString(); - IR_END() - - // Print instruction parameters - printValueList(out, node->inputs()); - return out; - } - - std::ostream& printReturn( - std::ostream& out, - const Node* node, - const size_t level) { - indent(out, level); - const auto& returns = node->inputs(); - if (returns.size() > 0) { - out << "return "; - if (returns.size() > 1) { - printValueList(out, returns); - } else { - printValue(out, returns[0]); - } - out << "\n"; - } - return out; - } - - std::ostream& printBlock( - std::ostream& out, - const Block* root, - const size_t level) { - for (const auto* node : root->nodes()) { - printNode(out, node, level); - } - - printNode(out, root->return_node(), level); - - return out; - } - - inline bool isNameUnique(std::string& name, const Value* val) const { - auto generated_name_value = generated_names_.find(name); - if (generated_name_value != generated_names_.end() && - generated_name_value->second != val) { - // Found a generated name match, check that it's for a different value - return false; - } - return graph_.uniqueNames().find(name) == graph_.uniqueNames().end(); - } - - std::ostream& printValue(std::ostream& out, const Value* val) { - auto cached_name = value_names_.find(val); - if (cached_name != value_names_.end()) { - // If this value has been seen before, print out cached name - out << cached_name->second; - return out; - } - - const auto node = val->node(); - if (node->kind() == prim::Constant) { - // printAttributeValue(out, node->attributeNames()[0], node); - node->printValue(out, node->attributeNames()[0]); - return out; - } - - if (skipped_nodes_.count(node) > 0) { - skipped_nodes_.erase(node); - // Node was skipped earlier, so print it now - printRHS(out, node); - return out; - } - - auto name_source = val; - - auto aliased_name = aliases_.find(val); - if (aliased_name != aliases_.end()) { - name_source = aliased_name->second; - } - - auto name = name_source->uniqueName(); - - bool using_generated_name = false; - if (isdigit(name.at(0))) { - std::stringstream ss; - ss << "t" << name; - name = ss.str(); - using_generated_name = true; - } else if (name.find_last_of('.') != std::string::npos) { - // Make unique name a valid variable name (e.g. 
a.1 -> a1) - name.erase(std::remove(name.begin(), name.end(), '.'), name.end()); - using_generated_name = true; - } - - if (using_generated_name) { - // Make sure name is unique - size_t suffix = 0; - while (!isNameUnique(name, name_source)) { - std::stringstream ss; - ss << name << suffix; - name = ss.str(); - ++suffix; - } - - // These names aren't in the Graph's list of names but we still need to - // make sure there are no name conflicts - generated_names_[name] = name_source; - } - - value_names_[val] = name; - out << name; - return out; - } - - public: - PrettyPrintPass(const Graph& graph) : graph_(graph) {} - - std::ostream& run(std::ostream& out) { - out << "def script"; - const Node* params = graph_.block()->param_node(); - printValueList(out, params->outputs()); - out << ":\n"; - - // Print body - printBlock(out, graph_.block(), 1); - - printReturn(out, graph_.block()->return_node(), 1); - - return out; - } -}; - -TORCH_API std::ostream& PrettyPrint(std::ostream& out, const Graph& graph) { - return PrettyPrintPass(graph).run(out); -} - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/pretty_print.h b/torch/csrc/jit/passes/pretty_print.h deleted file mode 100644 index 786edbcb78f69..0000000000000 --- a/torch/csrc/jit/passes/pretty_print.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "torch/csrc/jit/ir.h" - -namespace torch { namespace jit { - -TORCH_API std::ostream& PrettyPrint(std::ostream& out, const Graph& graph); - -}} diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp new file mode 100644 index 0000000000000..153d588e90c95 --- /dev/null +++ b/torch/csrc/jit/passes/python_print.cpp @@ -0,0 +1,587 @@ +#include "torch/csrc/jit/passes/python_print.h" +#include "torch/csrc/jit/attributes.h" +#include "torch/csrc/jit/generic_if.h" +#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/resource_guard.h" +#include "torch/csrc/jit/script/error_report.h" + +namespace torch { +namespace jit { + + +class PythonPrintPass { + std::ostream& out; + + // constants are written to this table, and given then named CONSTANTS.cN + // where N is the index into this table. + + std::vector tensor_constants; + // When printing this node, is it safe to write it inline (i.e. without + // assigning a temporary variable + std::unordered_set output_inline_; + + // when we print this, should we error if the resulting output would + // not be able to be reparsed? + bool enforce_importable_; + + // what valid identifiers are in use for the current function + std::unordered_set used_names_; + + // for fork, + // subgraphs get added to the worklist, and will be printed later + std::vector> worklist; + + // scanValue, scanNode, scanBlock: + // decide if it is safe to omit the output of a temporary variable, + // and inline the expression into its use + // we only do this if + // (1) it is a constant, or + // (2) the temporary is unnamed, is single output, is used once, + // and would appear in the same order when the expression tree is reparsed. + // The last case can be checked + // becuase when we emit a expresion tree in the parser, + // we do a left-to-right postorder traversal of the expression tree (emit children, then emit op). + // The reverse of this is a right-to-left preorder traversal of the tree. 
+ // By doing a right-to-left preorder traversal of the inputs of a node, + // while also scanning the list of emitted nodes backward, we can see if + // they line up with what would happen when parsed the node as an expression. While they line + // up we collapse them into an inline expression. + + // The inductive step is that the right-most input should be produced by the node + // immediatly before the current node if it is in tree order. + + // block_point is the current node in the reverse linear scan of the emitted nodes + // v is the current value in the tree traversal that may match with block_point's output. + const Node* scanValue(const Node* block_point, const Value* v) { + const Node* n = v->node(); + JIT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0); + + if (n == block_point && // the node must be at the expected point of the typical tree traversal + n->outputs().size() == 1 && // there must be only 1 values, otherwise we need an assignment to handle the multiple outout values + v->uses().size() == 1 && // if it is used more than once, then we need a variable + n->blocks().size() == 0 && // don't try to inline control blocks, + !v->hasUniqueName() && // if it has a name set, then it was written as a variable so preserve that + (v->uses().at(0).user->kind() != prim::Loop // if it is a loop-carried input, we need a variable + || v->uses().at(0).offset < 2)) { // otherwise the condition or trip count may be emitted in the wrong order w.r.t. to it + // recursively see if we can inline the inputs to this input + block_point = scanNode(block_point); + output_inline_.insert(n); + } else if (n->kind() == prim::Constant) { + // constant nodes can always be inlined, we will de-dup them on parsing + // and put them at the top of the function regardless + output_inline_.insert(n); + } + return block_point; + } + + const Node* scanNode(const Node* n) { + // don't bother to scan nodes we have already determined to be inline + if(output_inline_.count(n)) { + return n; + } + for(auto b : n->blocks()) { + scanBlock(b); + } + const Node* block_point = n->prev(); + for(auto it = n->inputs().rbegin(), + end = n->inputs().rend(); it != end; ++it) { + block_point = scanValue(block_point, *it); + } + return block_point; + } + + void scanBlock(const Block* b) { + scanNode(b->return_node()); + for(auto node : b->nodes().reverse()) { + scanNode(node); + } + } + + // get a new name unique across calls to uniqueName() and + // anything we have used. + size_t next_id = 0; + std::string genName(const std::string& candidate) { + // some names are valid identifiers but off limits because + // they are keywords or namespaces used in the output + const static std::unordered_set reserved_names = { + "aten", + "prim", + "CONSTANTS", + "int", + "float", + "bool", + "print", + "return", + "while", + "for", + "if", + "True", + "False", + "fork", + }; + + std::string name = candidate; + while(used_names_.count(name) || reserved_names.count(name)) { + name = candidate + std::to_string(next_id++); + } + used_names_.insert(name); + return name; + } + + // unique names might not be valid identifiers, + // force them to be by rewriting them + static std::string makeValidIdentifier(const std::string& candidate) { + std::stringstream ss; + if (candidate.size() == 0 || isdigit(candidate[0])) + ss << "_"; + for(char c : candidate) { + if (isupper(c) || islower(c) || isdigit(c) || c == '_') + ss << c; + else + ss << '_'; + } + return ss.str(); + } + // if we have to assign 'v' a name, what should it be? 
+ // use the uniqueName if it was set, otherwise generate a name. + std::string genUniqueNameFor(const Value* v) { + return genName( + v->hasUniqueName() ? makeValidIdentifier(v->uniqueName()) : "t"); + } + + // map from Value to how it should be printed at each use + std::unordered_map value_names_; + + std::string useOf(const Value* v) const { + return value_names_.at(v); + } + void assignValue(const Value* v, const std::string& s) { + value_names_[v] = s; + } + void assignValue(const Value* v, const Value* w) { + assignValue(v, useOf(w)); + } + void assignValuesToTheirUniqueNames(at::ArrayRef values) { + for(auto v : values) { + assignValue(v, genUniqueNameFor(v)); + } + } + + size_t level = 0; + // indent to the current indent level + std::ostream& indent() { + for (size_t i = 0; i < level; ++i) { + out << " "; + } + return out; + } + + ResourceGuard WithIndented() { + level++; + return ResourceGuard([this]{ + level--; + }); + } + + template + void zipWith( + at::ArrayRef list_a, + at::ArrayRef list_b, + F action) const { + auto it_a = list_a.begin(); + auto it_b = list_b.begin(); + + if (list_a.size() != list_b.size()) { + AT_ERROR("Pretty printer expected 2 lists of same size"); + } + + for (; it_a != list_a.end(); ++it_a, ++it_b) { + action(*it_a, *it_b); + } + } + + void printValueList(std::ostream& stmt, at::ArrayRef list, const char* begin = "", const char* end = "") { + stmt << begin; + auto delimiter = ""; + for (const auto* value : list) { + stmt << delimiter; + stmt << useOf(value); + delimiter = ", "; + } + stmt << end; + } + + void printAssignment( + at::ArrayRef lhs, + at::ArrayRef rhs) { + if(lhs.size() > 0) { + indent(); + printValueList(out, lhs); + out << " = "; + printValueList(out, rhs); + out << "\n"; + } + } + + void printIf( + const Node* node) { + assignValuesToTheirUniqueNames(node->outputs()); + auto cond = node->inputs()[0]; + const auto if_block = node->blocks()[0]; + const auto else_block = node->blocks()[1]; + indent() << "if " << useOf(cond) << ":\n"; + { + auto guard = WithIndented(); + // Print node contents + printBlock(if_block); + printAssignment(node->outputs(), if_block->outputs()); + } + indent() << "else:\n"; + { + auto guard = WithIndented(); + printBlock(else_block); + printAssignment(node->outputs(), else_block->outputs()); + } + } + + // our way of encoding loops makes them difficult to turn back into python syntax. + // we have to check properties of the condition and trip count inputs to + // figure out which one it initially was + static bool shouldEmitAsForLoop(const Node* node) { + const auto body_block = node->blocks()[0]; + auto trip_count = toIValue(node->inputs().at(0)); + auto cond_input = toIValue(node->inputs().at(1)); + auto cond_next = toIValue(body_block->outputs().at(0)); + + bool condition_is_always_true = cond_input && cond_input->toBool() && cond_next && + cond_next->toBool(); + bool trip_count_is_specified = !trip_count || // trip is not a constant + trip_count->toInt() != std::numeric_limits::max() || // it is a constant but not the default one + body_block->inputs().at(0)->uses().size() > 0; // it is actually being used in the body. 
+ + if (condition_is_always_true) { + // if the trip count was not specified this was a user-written while True: + return trip_count_is_specified; + } else { + // this must be a while loop, but check that there isn't _also_ a trip count + if (trip_count_is_specified) { + throw script::ErrorReport(node->getSourceLocation()) + << "loop cannot be printed as python because it has gone through an optimization " + << "that combined while and for loops. File a bug."; + } + return false; + } + } + + void printLoop(const Node* node) { + + // Loop carried dependencies are handled by assigning their initial + // values to the node->outputs() before the loop, + // and assign node->outputs() to the new values at the end of each trip. + + + bool emit_as_for_loop = shouldEmitAsForLoop(node); + const auto body_block = node->blocks()[0]; + + assignValuesToTheirUniqueNames(node->outputs()); + // Add aliases for loop-carried dependencies + zipWith( + body_block->inputs().slice(1), // Start at 1 to ignore trip count + node->outputs(), + [&](const Value* block_input, const Value* node_output) { + assignValue(block_input, node_output); + }); + + // Print initial assignments of loop node outputs = loop node inputs + printAssignment(node->outputs(), node->inputs().slice(2)); + + auto trip_count_in_block = body_block->inputs().at(0); + assignValuesToTheirUniqueNames(trip_count_in_block); + // Loop header + if (emit_as_for_loop) { + indent(); + out << "for " << useOf(trip_count_in_block) << " in range(" + << useOf(node->inputs().at(0)) << "):\n"; + } else { + // note: trip_count_in_block is unused because this is a while loop, + // so we reuse the Value* as a stand-in for the loop condition + printAssignment(trip_count_in_block, node->inputs().at(1)); + indent(); + out << "while " << useOf(trip_count_in_block) << ":\n"; + } + // Loop body + { + ResourceGuard indent = WithIndented(); + printBlock(body_block); + // Update block outputs to block inputs for next loop iteration + // skip the assignment to the new condition in for loops because + // the condition is always True + size_t offset = emit_as_for_loop ? 
1 : 0; + printAssignment(body_block->inputs().slice(offset), body_block->outputs().slice(offset)); + } + } + + void printNode(const Node* node) { + switch (node->kind()) { + case prim::Return: + if (node->inputs().size() > 0) { + indent(); + out << "return "; + printValueList(out, node->inputs()); + out << "\n"; + } + break; + case prim::Loop: + printLoop(node); + break; + case prim::If: + printIf(node); + break; + case prim::TupleUnpack: + case prim::ListUnpack: + assignValuesToTheirUniqueNames(node->outputs()); + indent(); + // TupleUnpack(unpacked) turns into an assignment op that forces + // the unpack to be inserted when parsed back in: + // a, b, = unpacked + // a, = unpacked # trailing comma forces an unpack to happen + if (node->outputs().size() > 0) { + printValueList(out, node->outputs(), "", ", = "); + } + out << useOf(node->input()) << "\n"; + break; + default: + + std::stringstream ss; + printRHS(ss, node); + + // this node is safe to inline, so assign the output value + // to that expression directly + // guard against really long lines + if (output_inline_.count(node) > 0 && ss.str().size() + level * 2 < 40) { + assignValue(node->output(), ss.str()); + return; + } + assignValuesToTheirUniqueNames(node->outputs()); + indent(); + // Print outputs + if (node->outputs().size() > 0) { + printValueList(out, node->outputs()); + out << " = "; + } + out << ss.str() << "\n"; + } + } + + size_t addTensorConstant(at::Tensor t) { + tensor_constants.emplace_back(std::move(t)); + return tensor_constants.size() - 1; + } + + // Prints the RHS value of a Node, e.g. `aten.add(x, y)` + void printRHS(std::ostream& stmt, const Node* node) { + switch(node->kind()) { + case PythonOp::Kind: { + auto value = static_cast(node); + if (enforce_importable_) { + throw script::ErrorReport(node->getSourceLocation()) + << "could not export python function call " << value->name() + << ". Remove calls to python functions before export."; + } + + stmt << "^" << value->name(); + value->writeScalars(stmt); + printValueList(stmt, node->inputs(), "(", ")"); + } break; + case prim::Constant: { + IValue v = toIValue(node->output()).value(); + if(v.isTensor()) { + stmt << "CONSTANTS.c" << addTensorConstant(std::move(v).toTensor()); + } else if(v.isString()) { + // TODO: escape the string correctly by implementing a subset of + // string escapes in both printing and parsing. + stmt << "\"" << v.toStringRef() << "\""; + } else if(v.isTensorList()) { + auto tl = v.toTensorListRef(); + stmt << "["; + const char* delim = ""; + for(at::Tensor t : tl) { + stmt << delim << "CONSTANTS.c" << addTensorConstant(std::move(t)); + delim = ", "; + } + stmt << "]"; + } else { + // TODO: ensure floats always print with their periods + stmt << v; + } + } break; + case prim::None: + case prim::NoneGenerator: + case prim::Undefined: { + stmt << "None"; + } break; + case prim::FloatToInt: { + printValueList(stmt, node->inputs(), "int(", ")"); + } break; + case prim::StringToFloat: + case prim::IntToFloat: { + printValueList(stmt, node->inputs(), "float(", ")"); + } break; + case prim::TensorToBool: { + printValueList(stmt, node->inputs(), "bool(", ")"); + } break; + case prim::Print: { + printValueList(stmt, node->inputs(), "print(",")"); + } break; + case prim::TupleConstruct: { + printValueList( + stmt, node->inputs(), "(", node->inputs().size() == 1 ? 
",)" : ")"); + } break; + case prim::TupleIndex: { + stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::index) << "]"; + } break; + case prim::TupleSlice: { + stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":" + << node->i(attr::end) << "]"; + } break; + case prim::ListConstruct: { + // TODO: when the list is empty and is not a list of tensors, + // we need to annotate it, otherwise it won't be possible + // to infer the type on import + printValueList(stmt, node->inputs(), "[", "]"); + } break; + case prim::fork: { + // the subgraph gets emitted as another function + auto name = genName("__forked_function"); + std::shared_ptr graph = node->g(attr::Subgraph); + worklist.emplace_back(graph.get(), name); + // and we put a call to fork which invokes that function. + stmt << "fork(" << name; + for(const Value* v : node->inputs()) { + stmt << ", " << useOf(v); + } + stmt << ")"; + } break; + default: { + Symbol kind = node->kind(); + stmt << kind.ns().toUnqualString() << "." << kind.toUnqualString(); + printValueList(stmt, node->inputs(), "(", ")"); + } break; + } + } + + std::ostream& printBlock( + const Block* root) { + for (const auto* node : root->nodes()) { + printNode(node); + } + return out; + } + + void printOneFunction(const Graph& graph, const std::string& name) { + used_names_.clear(); // each graph can reuse local names + // current graph is used to de-dup names within a single graph + scanBlock(graph.block()); + assignValuesToTheirUniqueNames(graph.inputs()); + out << "def " << name << "(\n"; + const char * delim = " "; + for(auto input : graph.inputs()) { + out << delim << useOf(input) << ": " << input->type()->python_str(); + delim = ",\n "; + } + out << ") -> " << resultType(graph)->python_str() << ":\n"; + { + auto guard = WithIndented(); + // Print body + printBlock(graph.block()); + printNode(graph.block()->return_node()); + } + } + + public: + PythonPrintPass( + std::ostream& out_, + bool enforce_importable = false) + : out(out_), enforce_importable_(enforce_importable) {} + + // TODO: we should consider forcing functions to return a single value + // instead of handling this tuple logic both in the compiler and the printer + TypePtr resultType(const Graph& graph) { + if (graph.outputs().size() == 1) { + return graph.outputs().at(0)->type(); + } else { + return TupleType::create( + fmap(graph.outputs(), [&](const Value* v) { return v->type(); })); + } + } + + void printFunction(const Graph& graph, const std::string& name) { + printOneFunction(graph, name); + while(!worklist.empty()) { + out << "\n\n"; + auto work = worklist.back(); + worklist.pop_back(); + printOneFunction(*work.first, work.second); + } + } +}; + +TORCH_API std::ostream& PythonPrint(std::ostream& out, const Graph& graph) { + PythonPrintPass(out).printFunction(graph, "script"); + return out; +} + +TORCH_API bool printerHasSpecialCaseFor(Symbol sym) { + // WARNING: by adding a value to this set, you are asserting + // that you have also added special handling of this symbol to + // the printer above. Not adding handling will cause import and export + // of modules with this new operator to fail. This is only required + // for operators without schema. Prefer registering your operator with + // schema to editing this list here. 
These cases should only be things + // that require special handling because they do not fit normal schema + const static std::unordered_set handled = { + prim::BoolToTensor, + prim::Constant, + prim::TensorToBool, + prim::FloatToInt, + prim::fork, + prim::IntToFloat, + prim::ListConstruct, + prim::ListUnpack, + prim::None, + prim::NoneGenerator, + prim::Print, + prim::PythonOp, + prim::StringToFloat, + prim::TupleConstruct, + prim::TupleIndex, + prim::TupleSlice, + prim::TupleUnpack, + prim::Undefined, + }; + + // WARNING: by adding a value to this set, you are asserting that your + // primitive is only ever added during optimization and does not need + // to be correctly printed for export (a process that happens before + // optimization passes run) + const static std::unordered_set unneeded = { + onnx::Reshape, // only used in onnx + onnx::Shape, // only used in onnx + prim::AnyDefined, // temporarily inserted by autograd + prim::AutogradAdd, // temporarily inserted by autograd + prim::ConstantChunk, // optimization pass adds it + prim::DifferentiableGraph, // optimization pass adds it + prim::Drop, // used in interpreter only + prim::FusedConcat, // optimization pass adds it + prim::FusionGroup, // optimization pass adds it + prim::Load, // used in interpreter only + prim::MMTreeReduce, // used in batched execution only + prim::Store, // used in interpreter only + + }; + + return handled.count(sym) || unneeded.count(sym); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/python_print.h b/torch/csrc/jit/passes/python_print.h new file mode 100644 index 0000000000000..263ed3a053379 --- /dev/null +++ b/torch/csrc/jit/passes/python_print.h @@ -0,0 +1,13 @@ +#pragma once +#include "torch/csrc/WindowsTorchApiMacro.h" +#include + +namespace c10 { + struct Symbol; +} + +namespace torch { namespace jit { +struct Graph; +TORCH_API std::ostream& PythonPrint(std::ostream& out, const Graph& graph); +TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym); +}} diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 050bd2d951914..503f52f058bb4 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -178,7 +178,7 @@ RegisterOperators reg({ }; }), Operator( - prim::TensorDevice, + "prim::device(Tensor a) -> int[]", [](const Node* node) -> Operation { return [](Stack& stack) { at::Tensor a; @@ -189,7 +189,7 @@ RegisterOperators reg({ }; }), Operator( - prim::TensorDType, + "prim::dtype(Tensor a) -> int", [](const Node* node) -> Operation { return [](Stack& stack) { at::Tensor a; @@ -199,7 +199,7 @@ RegisterOperators reg({ }; }), Operator( - prim::TensorShape, + "prim::shape(Tensor a) -> int[]", [](const Node* node) -> Operation { return [](Stack& stack) { at::Tensor a; @@ -250,7 +250,7 @@ RegisterOperators reg({ }; }), Operator( - prim::RaiseException, + "prim::RaiseException(str msg) -> ()", [](const Node* node) -> Operation { return [](Stack& stack) { throw JITException(pop(stack).toStringRef()); @@ -364,7 +364,7 @@ RegisterOperators reg({ const auto & elems = t->elements(); std::vector output_elems; for (int64_t i = beg_ind; i < end_ind; ++i) { - output_elems.push_back(elems.at(i)); + output_elems.emplace_back(elems.at(i)); } push(stack, Tuple::create(std::move(output_elems))); return 0; @@ -378,7 +378,7 @@ RegisterOperators reg({ auto tup = pop(stack).toTuple(); const auto & elems = tup->elements(); // index is normalized to be positive at compile time - stack.push_back(elems.at(index)); + 
stack.emplace_back(elems.at(index)); return 0; }; }), @@ -493,7 +493,7 @@ RegisterOperators reg({ std::vector vals; vals.reserve(num_inputs); for (size_t i = stack_size - num_inputs; i < stack_size; ++i) { - vals.push_back(std::move(stack[i]).toTensor()); + vals.emplace_back(std::move(stack[i]).toTensor()); } drop(stack, num_inputs); push(stack, std::move(vals)); @@ -505,7 +505,7 @@ RegisterOperators reg({ std::vector vals; vals.reserve(num_inputs); for (size_t i = stack_size - num_inputs; i < stack_size; ++i) { - vals.push_back(std::move(stack[i])); + vals.emplace_back(std::move(stack[i])); } drop(stack, num_inputs); push(stack, std::move(vals)); diff --git a/torch/csrc/jit/scope.cpp b/torch/csrc/jit/scope.cpp index a6613d6861dfc..de071f3a3b1d2 100644 --- a/torch/csrc/jit/scope.cpp +++ b/torch/csrc/jit/scope.cpp @@ -2,11 +2,7 @@ #include "torch/csrc/jit/operator.h" -#include "torch/csrc/autograd/function.h" -#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/assertions.h" -#include "torch/csrc/jit/script/compiler.h" -#include "torch/csrc/jit/passes/pretty_print.h" #include #include diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index dce0fbab588ab..54533791e98c4 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -265,7 +265,7 @@ struct Environment { void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) { Value* as_simple_value = asSimple(value); - if (as_simple_value) + if (as_simple_value && !as_simple_value->hasUniqueName()) as_simple_value->setUniqueName(name); // prevent re-assignment involving any sugared values // any reassignment like: @@ -886,7 +886,6 @@ struct to_ir { << (schema.returns().size() > 1 ? "s" : "") << " but found no return statement"; } - method.setSchema({def.name().name(), std::move(arguments), std::move(returns)}); // remove any uses of tuples that we inserted that are not needed LowerSimpleTuples(graph); @@ -1182,7 +1181,7 @@ struct to_ir { max_trip_count->range(), emitExpr(max_trip_count.value())); } else { max_trip_count_val = - materializeConstant((int64_t)INT_MAX, *graph, range, integral_constants); + materializeConstant(std::numeric_limits::max(), *graph, range, integral_constants); } if (cond) { cond_val = emitCond(cond.value()); @@ -1319,8 +1318,7 @@ struct to_ir { void emitRaise(const SourceRange& loc) { const std::string exception = "Exception"; auto string_input = insertConstant(*graph, exception, loc); - graph->insertNode(graph->create(prim::RaiseException, {string_input}, 0) - ->setSourceLocation(std::make_shared(loc))); + graph->insert(prim::RaiseException, {string_input}, {}, loc); } void emitAssert(const Assert& stmt) { @@ -2288,18 +2286,16 @@ std::shared_ptr SimpleValue::attr(SourceRange loc, Method & m, con Symbol::aten(builtin_cast_methods().at(field)), NamedValue(loc, "self", value)); } - if (field == "dtype") { - auto* node = m.graph()->create(prim::TensorDType, {value}); - node->output()->setType(IntType::get()); - return std::make_shared(m.graph()->insertNode(node)->output()); - } else if (field == "device") { - auto* node = m.graph()->create(prim::TensorDevice, {value}); - node->output()->setType(ListType::create(IntType::get())); - return std::make_shared(m.graph()->insertNode(node)->output()); - } else if (field == "shape") { - auto* node = m.graph()->create(prim::TensorShape, {value}); - node->output()->setType(ListType::create(IntType::get())); - return std::make_shared(m.graph()->insertNode(node)->output()); + 
// functions that are just direct property lookups on tensor + // must be registered as prim::(Tensor t) -> + static const std::unordered_set fields = { + "dtype", + "device", + "shape", + }; + if (fields.count(field)) { + auto r = m.graph()->insert(Symbol::fromQualString("prim::"+field), {value}); + return std::make_shared(r); } } if (getValue()->type()->isSubtypeOf(NumberType::get())) { diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h index 66296e8623cc3..5e834c3aaacaf 100644 --- a/torch/csrc/jit/script/compiler.h +++ b/torch/csrc/jit/script/compiler.h @@ -129,17 +129,23 @@ struct TORCH_API BuiltinFunction : public SugaredValue { }; struct TORCH_API BuiltinModule : public SugaredValue { - BuiltinModule(std::string name) - : name(std::move(name)) {} - std::string name; + BuiltinModule(std::string name, + c10::optional version = at::nullopt) + : name(std::move(name)) + , version(version) {} std::string kind() const override { return "builtin module"; } - std::shared_ptr attr(SourceRange loc, Method & m, const std::string& field) override { - return std::make_shared(Symbol::aten(field), c10::nullopt); + return std::make_shared(Symbol::fromQualString(name+"::"+field), c10::nullopt); } + +private: + std::string name; + // when we add operator versioning, emit this op as it exising at 'version' + // if not set, use the latest version + c10::optional version; }; // These SugaredValues have special handling in the compiler because they @@ -163,7 +169,7 @@ using Resolver = std::function(const std::string& inline std::shared_ptr nativeResolver(const std::string& name, Method& m, const SourceRange& loc){ if (name == "torch") { - return std::make_shared(name); + return std::make_shared("aten"); } return nullptr; } diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index d9848712ee7d1..7ac71755a70f3 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -8,7 +8,7 @@ namespace jit { namespace script { -Decl mergeTypesFromTypeComment(Decl decl, Decl type_annotation_decl, bool is_method) { +inline Decl mergeTypesFromTypeComment(Decl decl, Decl type_annotation_decl, bool is_method) { auto expected_num_annotations = decl.params().size(); if (is_method) { // `self` argument