
Update linspace and bump version number to 8 (pytorch#71486)
Summary:
Pull Request resolved: pytorch#71486

This PR adds upgraders for aten::linspace and aten::linspace.out, as the optional steps argument is deprecated and will soon be required. Old models will keep using a steps value of 100 when none is provided.
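
From the caller's side, the change looks roughly like this (a minimal Python sketch, assuming a build that includes this commit):

import torch

# steps is now a required argument in the schema.
x = torch.linspace(0, 1, steps=5)
assert x.numel() == 5

# Omitting steps used to warn and default to 100; it now fails outright:
#   torch.linspace(0, 1)  # raises: missing required argument 'steps'

# Old serialized models that omitted steps keep their historical behavior:
# the upgrader rewrites their calls to pass the old default of 100 explicitly.
y = torch.linspace(0, 1, steps=100)
assert y.numel() == 100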

Test Plan: buck-out/gen/caffe2/test/jit#binary.par -r TestUpgraders.test_aten_linspace

Reviewed By: cccclai, mruberry

Differential Revision: D33654308

fbshipit-source-id: 0e0138091da0b11d4f49156eeb6bcd7e46102a5b
(cherry picked from commit 931ae4a)
tugsbayasgalan authored and pytorchmergebot committed Feb 1, 2022
1 parent 5024c1b commit b28e696
Showing 16 changed files with 365 additions and 57 deletions.
11 changes: 1 addition & 10 deletions aten/src/ATen/native/RangeFactories.cpp
@@ -11,17 +11,8 @@
 namespace at { namespace native {


-Tensor& linspace_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, Tensor& result) {
-  const auto steps = optional_steps.value_or(100);
+Tensor& linspace_out(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
   TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
-
-  if (!optional_steps.has_value()) {
-    TORCH_WARN_ONCE(
-        "Not providing a value for linspace's steps is deprecated and will "
-        "throw a runtime error in a future release. This warning will appear "
-        "only once per process.");
-  }
-
   if (result.numel() != steps) {
     result.resize_({steps});
   }
7 changes: 3 additions & 4 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -552,18 +552,17 @@ TensorOptions linspace_logspace_infer_options(
 Tensor linspace(
     const Scalar& start,
     const Scalar& end,
-    c10::optional<int64_t> steps,
+    int64_t steps,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
   // See [Note: hacky wrapper removal for TensorOptions]
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

-  const auto steps_ = steps.value_or(100);
-  TORCH_CHECK(steps_ >= 0, "number of steps must be non-negative");
+  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
   auto result_options = linspace_logspace_infer_options(start, end, options, "torch.linspace()");
-  Tensor result = at::empty({steps_}, result_options);
+  Tensor result = at::empty({steps}, result_options);
   return at::linspace_out(result, start, end, steps);
 }

10 changes: 1 addition & 9 deletions aten/src/ATen/native/cuda/RangeFactories.cu
@@ -51,17 +51,9 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) {
 namespace at {
 namespace native {

-Tensor& linspace_cuda_out(const Scalar& start, const Scalar& end, c10::optional<int64_t> optional_steps, Tensor& result) {
-  const auto steps = optional_steps.value_or(100);
+Tensor& linspace_cuda_out(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
   TORCH_CHECK(steps >= 0, "number of steps must be non-negative");

-  if (!optional_steps.has_value()) {
-    TORCH_WARN_ONCE(
-        "Not providing a value for linspace's steps is deprecated and will "
-        "throw a runtime error in a future release. This warning will appear "
-        "only once per process.");
-  }
-
   if (result.numel() != steps) {
     result.resize_({steps});
   }
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -2613,9 +2613,9 @@

 - func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

-- func: linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-- func: linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)
+- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, Meta: linspace_out
     CUDA: linspace_cuda_out
20 changes: 16 additions & 4 deletions caffe2/serialize/versions.h
@@ -12,7 +12,7 @@ namespace serialize {
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;

 #if ENABLE_UPGRADERS
-constexpr uint64_t kMaxSupportedFileFormatVersion = 0x7L;
+constexpr uint64_t kMaxSupportedFileFormatVersion = 0x8L;
 #else
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 #endif
@@ -58,13 +58,25 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x6L;
 // 6. Write version string to `./data/version` instead of `version`.

 #if ENABLE_UPGRADERS
-// This is set to 7 from 3 due to a different interpretation of what
-// file format version is. Whenever there is new upgrader introduced,
+// [12/15/2021]
+// kProducedFileFormatVersion is set to 7 from 3 due to a different
+// interpretation of what file format version is.
+// Whenever there is new upgrader introduced,
 // this number should be bumped.
 // The reasons that version is bumped in the past:
 // 1. aten::div is changed at version 4
 // 2. aten::full is changed at version 5
 // 3. torch.package uses version 6
-constexpr uint64_t kProducedFileFormatVersion = 0x7L;
+// 4. Introduce new upgrader design and set the version number to 7
+//    mark this change
+// --------------------------------------------------
+// We describe new operator version bump reasons here:
+// 1) [01/24/2022]
+//    We bump the version number to 8 to update aten::linspace
+//    and aten::linspace.out to error out when steps is not
+//    provided. (see: https://github.com/pytorch/pytorch/issues/55951)
+// 2) ...
+constexpr uint64_t kProducedFileFormatVersion = 0x8L;
 #else
 constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 #endif
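
For intuition, the operator version map that upgraders consult behaves roughly like this sketch (hypothetical Python mirroring the Upgrader entries registered further down in this commit; the real lookup lives in the mobile runtime):

# Each entry mirrors Upgrader({min_version, max_version, upgrader_name, ...}).
OPERATOR_UPGRADERS = {
    "aten::linspace": (0, 7, "linspace_0_7"),
    "aten::linspace.out": (0, 7, "linspace_out_0_7"),
}

def pick_upgrader(op_name: str, model_operator_version: int):
    entry = OPERATOR_UPGRADERS.get(op_name)
    if entry is None:
        return None  # op was never changed; run it directly
    min_v, max_v, upgrader_name = entry
    # Models produced in [min_version, max_version] get the upgrader applied.
    return upgrader_name if min_v <= model_operator_version <= max_v else None

assert pick_upgrader("aten::linspace", 7) == "linspace_0_7"
assert pick_upgrader("aten::linspace", 8) is None  # version-8 models need no upgrade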
1 change: 1 addition & 0 deletions test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -85,6 +85,7 @@
     ("aten::linalg_svd_out", datetime.date(2022, 3, 31)),
     ("aten::_max_pool1d_cpu_forward", datetime.date(2022, 2, 8)),
     ("aten::_convolution_nogroup", datetime.date(9999, 1, 1)),
+    ("aten::linspace", datetime.date(2022, 3, 1)),  # TODO this will be removed soon
     ("aten::miopen_convolution_backward", datetime.date(9999, 1, 1)),
     ("aten::miopen_convolution_backward_bias", datetime.date(9999, 1, 1)),
     ("aten::miopen_convolution_backward_input", datetime.date(9999, 1, 1)),
64 changes: 60 additions & 4 deletions test/jit/test_save_load_for_op_version.py
@@ -6,6 +6,7 @@
 import sys
 import hypothesis.strategies as st
 from hypothesis import example, settings, given
+from typing import Union

 import torch
@@ -240,10 +241,6 @@ def forward(self, a, b: int):
         except Exception as e:
             self.skipTest("Failed to load fixture!")

-        for m in (v3_module_float, v3_module_int):
-            self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
-            self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument
-
         current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat)
         current_mobile_module_int = self._save_load_mobile_module(MyModuleInt)

@@ -425,3 +422,62 @@ def _helper(m, fn):
             self.assertEqual(mr, hr)

         _helper(v3_mobile_module, current_mobile_module)
+
+    def test_versioned_linspace(self):
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super(Module, self).__init__()
+
+            def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
+                c = torch.linspace(a, b, steps=5)
+                d = torch.linspace(a, b, steps=100)
+                return c, d
+
+        scripted_module = torch.jit.load(
+            pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl")
+
+        buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v7_mobile_module = _load_for_lite_interpreter(buffer)
+
+        current_mobile_module = self._save_load_mobile_module(Module)
+
+        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
+        for (a, b) in sample_inputs:
+            (output_with_step, output_without_step) = v7_mobile_module(a, b)
+            (current_with_step, current_without_step) = current_mobile_module(a, b)
+            # when no step is given, should have used 100
+            self.assertTrue(output_without_step.size(dim=0) == 100)
+            self.assertTrue(output_with_step.size(dim=0) == 5)
+            # outputs should be equal to the newest version
+            self.assertEqual(output_with_step, current_with_step)
+            self.assertEqual(output_without_step, current_without_step)
+
+    def test_versioned_linspace_out(self):
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super(Module, self).__init__()
+
+            def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
+                return torch.linspace(a, b, steps=100, out=out)
+
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
+        loaded_model = torch.jit.load(model_path)
+        buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        v7_mobile_module = _load_for_lite_interpreter(buffer)
+        current_mobile_module = self._save_load_mobile_module(Module)
+
+        sample_inputs = (
+            (3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
+            (-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
+            (4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)),
+            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)),
+        )
+        for (start, end, out_for_old, out_for_new) in sample_inputs:
+            output = v7_mobile_module(start, end, out_for_old)
+            output_current = current_mobile_module(start, end, out_for_new)
+            # when no step is given, should have used 100
+            self.assertTrue(output.size(dim=0) == 100)
+            # "Upgraded" model should match the new version output
+            self.assertEqual(output, output_current)
43 changes: 41 additions & 2 deletions test/jit/test_upgraders.py
@@ -26,8 +26,15 @@ def _load_model_version(self, loaded_model):
         torch.jit.save(loaded_model, buffer)
         buffer.seek(0)
         zipped_model = zipfile.ZipFile(buffer)
-        version = int(zipped_model.read('archive/version').decode("utf-8"))
-        return version
+        # there was a change in how we store version number
+        # in a package between version 3 and 7.
+        # So we have to check for both.
+        try:
+            version = int(zipped_model.read('archive/version').decode("utf-8"))
+            return version
+        except KeyError:
+            version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
+            return version

     # TODO (tugsuu) We should ideally be generating this test cases.
     def test_populated_upgrader_graph(self):
@@ -131,6 +138,38 @@ def test_func():
         version = self._load_model_version(loaded_func)
         self.assertTrue(version == 4)

+    @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
+    def test_aten_linspace(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
+        loaded_model = torch.jit.load(model_path)
+        sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
+        for (a, b) in sample_inputs:
+            output_with_step, output_without_step = loaded_model(a, b)
+            # when no step is given, should have used 100
+            self.assertTrue(output_without_step.size(dim=0) == 100)
+            self.assertTrue(output_with_step.size(dim=0) == 5)
+
+        version = self._load_model_version(loaded_model)
+        self.assertTrue(version == 8)
+
+    @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
+    def test_aten_linspace_out(self):
+        model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
+        loaded_model = torch.jit.load(model_path)
+        sample_inputs = (
+            (3, 10, torch.empty((100,), dtype=torch.int64)),
+            (-10, 10, torch.empty((100,), dtype=torch.int64)),
+            (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
+            (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
+        )
+        for (a, b, c) in sample_inputs:
+            output = loaded_model(a, b, c)
+            # when no step is given, should have used 100
+            self.assertTrue(output.size(dim=0) == 100)
+
+        version = self._load_model_version(loaded_model)
+        self.assertTrue(version == 8)
+
     @unittest.skipIf(not _is_upgraders_enabled(), "Skipping because upgraders are not enabled")
     def test_aten_test_serialization(self):
         model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
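
The _load_model_version helper above works because a TorchScript archive is a zip file whose version record moved from archive/version to archive/.data/version (see change 6 in versions.h). A standalone sketch of the same probe ("model.pt" is a placeholder path; the root folder inside the zip is typically "archive"):

import zipfile

with zipfile.ZipFile("model.pt") as zf:
    names = set(zf.namelist())
    # Newer archives store the record under .data/; older ones at the top level.
    record = "archive/.data/version" if "archive/.data/version" in names else "archive/version"
    print(int(zf.read(record).decode("utf-8")))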
101 changes: 101 additions & 0 deletions test/mobile/test_upgrader_bytecode_table_example.cpp
@@ -46,6 +46,14 @@ getOperatorVersionMapForMobile() {
       std::vector<Upgrader>({
           Upgrader({0, 3, "div__Tensor_0_3", 3})
       })},
+      {std::string("aten::linspace"),
+       std::vector<Upgrader>({
+           Upgrader({0, 7, "linspace_0_7", 7})
+       })},
+      {std::string("aten::linspace.out"),
+       std::vector<Upgrader>({
+           Upgrader({0, 7, "linspace_out_0_7", 8})
+       })},
  });
  return operatorVersionMapForMobile;
 }
@@ -279,6 +287,99 @@ const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
               OperatorString({"aten::div", "out_mode", 4}),
           }), // operators list
       }),
+      ByteCodeFunctionWithOperator({
+          mobile::Function::registerFunc(
+              "linspace_0_7",
+              std::vector<Instruction>({
+                  Instruction{OpCode::STOREN, 1, 7},
+                  Instruction{OpCode::LOAD, 3, 0},
+                  Instruction{OpCode::LOADC, 0, 0},
+                  Instruction{OpCode::OP, 0, 0},
+                  Instruction{OpCode::JF, 10, 0},
+                  Instruction{OpCode::LOAD, 1, 0},
+                  Instruction{OpCode::LOAD, 2, 0},
+                  Instruction{OpCode::LOADC, 1, 0},
+                  Instruction{OpCode::LOAD, 4, 0},
+                  Instruction{OpCode::LOAD, 5, 0},
+                  Instruction{OpCode::LOAD, 6, 0},
+                  Instruction{OpCode::LOAD, 7, 0},
+                  Instruction{OpCode::OP, 1, 0},
+                  Instruction{OpCode::JMP, 10, 0},
+                  Instruction{OpCode::LOAD, 1, 0},
+                  Instruction{OpCode::LOAD, 2, 0},
+                  Instruction{OpCode::LOAD, 3, 0},
+                  Instruction{OpCode::OP, 2, 0},
+                  Instruction{OpCode::LOAD, 4, 0},
+                  Instruction{OpCode::LOAD, 5, 0},
+                  Instruction{OpCode::LOAD, 6, 0},
+                  Instruction{OpCode::LOAD, 7, 0},
+                  Instruction{OpCode::OP, 1, 0},
+                  Instruction{OpCode::STORE, 8, 0},
+                  Instruction{OpCode::DROPR, 7, 0},
+                  Instruction{OpCode::DROPR, 6, 0},
+                  Instruction{OpCode::DROPR, 5, 0},
+                  Instruction{OpCode::DROPR, 4, 0},
+                  Instruction{OpCode::DROPR, 2, 0},
+                  Instruction{OpCode::DROPR, 1, 0},
+                  Instruction{OpCode::DROPR, 3, 0},
+                  Instruction{OpCode::MOVE, 8, 0},
+                  Instruction{OpCode::RET, 0, 0},
+              }), // instructions list,
+              std::vector<c10::IValue>({
+                  c10::IValue(),
+                  c10::IValue(100),
+              }), // constants list,
+              std::vector<c10::TypePtr>(), // types list,
+              8
+          ),
+          std::vector<OperatorString>({
+              OperatorString({"aten::__is__", "", 2}),
+              OperatorString({"aten::linspace", "", 7}),
+              OperatorString({"prim::unchecked_cast", "", 1}),
+          }), // operators list
+      }),
+      ByteCodeFunctionWithOperator({
+          mobile::Function::registerFunc(
+              "linspace_out_0_7",
+              std::vector<Instruction>({
+                  Instruction{OpCode::STOREN, 1, 4},
+                  Instruction{OpCode::LOAD, 3, 0},
+                  Instruction{OpCode::LOADC, 0, 0},
+                  Instruction{OpCode::OP, 0, 0},
+                  Instruction{OpCode::JF, 7, 0},
+                  Instruction{OpCode::LOAD, 1, 0},
+                  Instruction{OpCode::LOAD, 2, 0},
+                  Instruction{OpCode::LOADC, 1, 0},
+                  Instruction{OpCode::LOAD, 4, 0},
+                  Instruction{OpCode::OP, 1, 0},
+                  Instruction{OpCode::JMP, 7, 0},
+                  Instruction{OpCode::LOAD, 1, 0},
+                  Instruction{OpCode::LOAD, 2, 0},
+                  Instruction{OpCode::LOAD, 3, 0},
+                  Instruction{OpCode::OP, 2, 0},
+                  Instruction{OpCode::LOAD, 4, 0},
+                  Instruction{OpCode::OP, 1, 0},
+                  Instruction{OpCode::STORE, 5, 0},
+                  Instruction{OpCode::DROPR, 4, 0},
+                  Instruction{OpCode::DROPR, 2, 0},
+                  Instruction{OpCode::DROPR, 1, 0},
+                  Instruction{OpCode::DROPR, 3, 0},
+                  Instruction{OpCode::MOVE, 5, 0},
+                  Instruction{OpCode::RET, 0, 0},
+              }), // instructions list,
+              std::vector<c10::IValue>({
+                  c10::IValue(),
+                  c10::IValue(100),
+              }), // constants list,
+              std::vector<c10::TypePtr>(), // types list,
+              5
+          ),
+          std::vector<OperatorString>({
+              OperatorString({"aten::__is__", "", 2}),
+              OperatorString({"aten::linspace", "out", 4}),
+              OperatorString({"prim::unchecked_cast", "", 1}),
+          }), // operators list
+      }),
   });
   for (const auto& upgrader_function : upgrader_function_list) {
     for (const auto& op : upgrader_function.operators) {
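
Conceptually, the linspace_0_7 bytecode above is the compiled form of an upgrader whose logic, reduced to a plain Python sketch (dtype/layout/device/pin_memory plumbing omitted for brevity; this is a reconstruction, not the literal registered TorchScript), is:

from typing import Optional

import torch

def linspace_0_7(start: float, end: float, steps: Optional[int]) -> torch.Tensor:
    # The old schema allowed steps=None; replay the historical default of 100
    # so version-7 models keep producing the same outputs.
    if steps is None:
        return torch.linspace(start, end, 100)
    return torch.linspace(start, end, steps)

assert linspace_0_7(0.0, 1.0, None).numel() == 100
assert linspace_0_7(0.0, 1.0, 5).numel() == 5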
