[PyTorch Edge] Add backport to export old bytecode models (pytorch#56802)

Summary:
Add an API to backport a model from bytecode version vn to an older version vi. It accepts an input model (file or buffer) and produces an output model (file or buffer) with the requested bytecode version.

In this change, the input model can come from either a file or a buffer, and the output model can be written to either a file path or a buffer.
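
A minimal sketch of how the new API can be called from Python, based on the functions exercised in `test/mobile/test_bytecode.py` below (the model file names here are placeholders):
```
import io

from torch.jit.mobile import (
    _load_for_lite_interpreter,
    _get_model_bytecode_version,
    _backport_for_mobile,
    _backport_for_mobile_to_buffer,
)

# Check the bytecode version of an existing lite-interpreter model.
from_version = _get_model_bytecode_version("script_module_v5.ptl")

# Backport file -> file to the previous bytecode version; returns True on success.
success = _backport_for_mobile(
    "script_module_v5.ptl", "script_module_v4.ptl", from_version - 1)

# Backport file -> buffer; the returned bytes can be loaded and run directly.
model_buffer = _backport_for_mobile_to_buffer(
    "script_module_v5.ptl", from_version - 1)
backported_module = _load_for_lite_interpreter(io.BytesIO(model_buffer))
```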

When backport fails, the function returns false with a warning message:
```
/Users/chenlai/pytorch/cmake-build-debug/bin/test_jit --gtest_filter=LiteInterpreterTest.BackPortByteCodeModelV4:LiteInterpreterTest/*.BackPortByteCodeModelV4:*/LiteInterpreterTest.BackPortByteCodeModelV4/*:*/LiteInterpreterTest/*.BackPortByteCodeModelV4 --gtest_color=no
Testing started at 2:32 PM ...
CUDA not available. Disabling CUDA and MultiCUDA tests

[W backport.cpp:419] Warning: Backport doesn't support backport to version3 (function _backport_for_mobile_impl)
Process finished with exit code 0
```
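
On the Python side the same failure surfaces as a `False` return value rather than an exception; a small sketch (paths are placeholders, and 4 is the current minimum supported bytecode version per the tests below):
```
from torch.jit.mobile import _backport_for_mobile

# Requesting a version below the minimum supported one returns False
# (the warning above comes from the shared C++ implementation) instead of raising.
success = _backport_for_mobile("script_module_v5.ptl", "script_module_v3.ptl", 3)
if not success:
    print("Backport to bytecode version 3 is not supported")
```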

## Test
1. Run both `caffe2/test/cpp/jit/test_lite_interpreter.cpp` and `caffe2/test/mobile/test_bytecode.py`.
2. Run all prod models through the backport API.

Pull Request resolved: pytorch#56802

ghstack-source-id: 128425510

Test Plan: CI

Reviewed By: raziel, iseeyuan

Differential Revision: D27844651

fbshipit-source-id: 8a803cf6c76433ee0a3049b1a5570585d569f8d6
cccclai authored and facebook-github-bot committed May 8, 2021
1 parent e9c3ce3 commit 8c04593
Showing 13 changed files with 886 additions and 11 deletions.
2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
@@ -600,6 +600,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER)
list(APPEND TORCH_SRCS
${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/backport.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/backport_manager.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
99 changes: 99 additions & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -6,6 +6,8 @@
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/backport.h>
#include <torch/csrc/jit/mobile/backport_manager.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
@@ -631,6 +633,103 @@ TEST(LiteInterpreterTest, GetByteCodeVersion) {
AT_ASSERT(version_v4 == 4);
}

namespace {
void runAndCheckBytecodeModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
const std::vector<Tensor>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);

// Load and run the backported model, then compare the result with the
// expected result
mobile::Module m_mobile = _load_for_mobile(input_model_stream);

auto actual_result = m_mobile.forward(input_data);
std::vector<IValue> actual_result_list = actual_result.toTuple()->elements();

AT_ASSERT(actual_result_list.size() == expect_result_list.size());
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
AT_ASSERT(
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
}

void backportAllVersionCheck(
std::stringstream& test_model_file_stream,
std::vector<IValue>& input_data,
std::vector<Tensor>& expect_result_list,
const int64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);

// Backport script_module_v5.ptl to an older version
constexpr int64_t minimum_to_version = 4;
int64_t current_to_version = from_version - 1;

std::ostringstream oss;
// Verify that all candidate to_versions work as expected. Backporting to any
// version no smaller than minimum_to_version should succeed.
while (current_to_version >= minimum_to_version) {
// Reset the stream content (clear() by itself only resets the error flags)
oss.str("");
oss.clear();
bool backPortSuccess =
_backport_for_mobile(test_model_file_stream, oss, current_to_version);
AT_ASSERT(backPortSuccess);

// Check backport model version
std::stringstream iss(oss.str());
auto backport_version = _get_model_bytecode_version(iss);
AT_ASSERT(backport_version == current_to_version);

// Load and run the backported model, then compare the result with the
// expected result
runAndCheckBytecodeModel(
iss, input_data, expect_result_list, current_to_version);

current_to_version--;
}
// backport to minimum version - 1 should fail
oss.str("");
oss.clear();
bool backPortSuccess =
_backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1);
AT_ASSERT(!backPortSuccess);
}

} // namespace

TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
torch::jit::Module module("m");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
module.register_parameter("bias", torch::ones({20}), false);
module.define(R"(
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
return (x1, x2, x3)
)");

torch::jit::Module module_freeze = freeze(module);

std::stringstream input_model_stream;
module_freeze._save_for_mobile(input_model_stream);
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<Tensor> expect_result_list;
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
expect_result_list.emplace_back(
at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
backportAllVersionCheck(
input_model_stream,
input_data,
expect_result_list,
caffe2::serialize::kProducedBytecodeVersion);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, SequentialModuleInfo) {
Module a("A");
248 changes: 243 additions & 5 deletions test/mobile/test_bytecode.py
@@ -1,4 +1,15 @@
from torch.jit.mobile import _get_model_bytecode_version
import fnmatch
import io
import shutil
import tempfile
import torch
import torch.utils.show_pickle
from torch.utils.mobile_optimizer import optimize_for_mobile
from torch.jit.mobile import (
_load_for_lite_interpreter,
_get_model_bytecode_version,
_backport_for_mobile_to_buffer,
_backport_for_mobile)
from torch.testing._internal.common_utils import TestCase, run_tests
from pathlib import Path

@@ -14,17 +25,244 @@
# increment = torch.ones([2, 4], dtype=torch.float64)
# return self.x + y + increment

# output_model_path = pathlib.Path(tmpdirname, "script_module_v5.ptl")
# output_model_path = Path(tmpdirname, "script_module_v5.ptl")
# script_module = torch.jit.script(TestModule(1))
# optimized_scripted_module = optimize_for_mobile(script_module)
# exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(
# str(output_model_path))

SCRIPT_MODULE_V4_BYTECODE_PKL = '''
(4,
('__torch__.*.TestModule.forward',
(('instructions',
(('STOREN', 1, 2),
('DROPR', 1, 0),
('LOADC', 0, 0),
('LOADC', 1, 0),
('MOVE', 2, 0),
('OP', 0, 0),
('LOADC', 1, 0),
('OP', 1, 0),
('RET', 0, 0))),
('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
('constants',
(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
0,
(2, 4),
(4, 1),
False,
collections.OrderedDict()),
1)),
('types', ()),
('register_size', 2)),
(('arguments',
((('name', 'self'),
('type', '__torch__.*.TestModule'),
('default_value', None)),
(('name', 'y'), ('type', 'int'), ('default_value', None)))),
('returns',
((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
'''

SCRIPT_MODULE_V5_BYTECODE_PKL = '''
(5,
('__torch__.*.TestModule.forward',
(('instructions',
(('STOREN', 1, 2),
('DROPR', 1, 0),
('LOADC', 0, 0),
('LOADC', 1, 0),
('MOVE', 2, 0),
('OP', 0, 0),
('LOADC', 1, 0),
('OP', 1, 0),
('RET', 0, 0))),
('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))),
('constants',
(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),),
0,
(2, 4),
(4, 1),
False,
collections.OrderedDict()),
1)),
('types', ()),
('register_size', 2)),
(('arguments',
((('name', 'self'),
('type', '__torch__.*.TestModule'),
('default_value', None)),
(('name', 'y'), ('type', 'int'), ('default_value', None)))),
('returns',
((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
'''

SCRIPT_MODULE_BYTECODE_PKL = {
4: {
"bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
"model_name": "script_module_v4.ptl"},
}

# The minimum version a model can be backported to.
# Needs to be updated when a bytecode version is completely retired.
MINIMUM_TO_VERSION = 4

class testVariousModelVersions(TestCase):
def test_get_model_bytecode_version(self):
script_module_v4 = pytorch_test_dir / "cpp" / "jit" / "script_module_v4.ptl"
version_v4 = _get_model_bytecode_version(script_module_v4)
assert(version_v4 == 4)
def check_model_version(model_path, expect_version):
actual_version = _get_model_bytecode_version(model_path)
assert(actual_version == expect_version)
for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
check_model_version(model_path, version)

def test_bytecode_values_for_all_backport_functions(self):
# Find the maximum version of the checked-in models, backport down to the minimum
# supported version, and compare the bytecode.pkl content at each step.
# It can't be merged into the test `test_all_backport_functions`, because optimization is
# dynamic and the content might change when the optimize function changes. This test focuses
# on bytecode.pkl content validation. The validation is not a byte-to-byte check, but
# wildcard matching (fnmatch), so specific content can be skipped in the comparison.
maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
current_from_version = maximum_checked_in_model_version

with tempfile.TemporaryDirectory() as tmpdirname:
while current_from_version > MINIMUM_TO_VERSION:
# Get the checked-in model for the current from_version
model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

# A temporary model file will be exported to this path, and then run through the
# bytecode.pkl content check.
tmp_output_model_path_backport = Path(tmpdirname, "tmp_script_module_backport.ptl")

current_to_version = current_from_version - 1
backport_success = _backport_for_mobile(input_model_path, tmp_output_model_path_backport, current_to_version)
assert(backport_success)

expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version]["bytecode_pkl"]

buf = io.StringIO()
torch.utils.show_pickle.main(
["", tmpdirname + "/" + tmp_output_model_path_backport.name + "@*/bytecode.pkl"],
output_stream=buf)
output = buf.getvalue()

actual_result_clean = "".join(output.split())
expect_result_clean = "".join(expect_bytecode_pkl.split())
isMatch = fnmatch.fnmatch(actual_result_clean, expect_result_clean)
assert(isMatch)

current_from_version -= 1
shutil.rmtree(tmpdirname)

def test_all_backport_functions(self):
# Backport from the latest bytecode version to the minimum supported version.
# Load and run the backported model, and check its version.
class TestModule(torch.nn.Module):
def __init__(self, v):
super().__init__()
self.x = v

def forward(self, y: int):
increment = torch.ones([2, 4], dtype=torch.float64)
return self.x + y + increment

module_input = 1
expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)

# temporary input model file and output model file will be exported in the temporary folder
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl")
script_module = torch.jit.script(TestModule(1))
optimized_scripted_module = optimize_for_mobile(script_module)
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path))

current_from_version = _get_model_bytecode_version(tmp_input_model_path)
current_to_version = current_from_version - 1
tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl")

while current_to_version >= MINIMUM_TO_VERSION:
# Backport the latest model to `current_to_version`, writing to the tmp file "tmp_script_module_backport.ptl"
backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, current_to_version)
assert(backport_success)

backport_version = _get_model_bytecode_version(tmp_output_model_path)
assert(backport_version == current_to_version)

# Load the backported model and run the forward method
mobile_module = _load_for_lite_interpreter(str(tmp_output_model_path))
mobile_module_result = mobile_module(module_input)
torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
current_to_version -= 1

# Check backport failure case
backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1)
assert(not backport_success)
# Need to clean the folder before it closes, otherwise the run will hit a git-not-clean error
shutil.rmtree(tmpdirname)

# Check just the file-to-file _backport_for_mobile mechanism, but not the function implementations
def test_backport_bytecode_from_file_to_file(self):
maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
maximum_checked_in_model_version]["model_name"]

if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_backport_model_path = Path(tmpdirname, "tmp_script_module_v5_backported_to_v4.ptl")
# backport from file
success = _backport_for_mobile(
script_module_v5_path,
tmp_backport_model_path,
maximum_checked_in_model_version - 1)
assert(success)

buf = io.StringIO()
torch.utils.show_pickle.main(
["", tmpdirname + "/" + tmp_backport_model_path.name + "@*/bytecode.pkl"],
output_stream=buf)
output = buf.getvalue()

expected_result = SCRIPT_MODULE_V4_BYTECODE_PKL
actual_result_clean = "".join(output.split())
expect_result_clean = "".join(expected_result.split())
isMatch = fnmatch.fnmatch(actual_result_clean, expect_result_clean)
assert(isMatch)

# Load model v4 and run forward method
mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path))
module_input = 1
mobile_module_result = mobile_module(module_input)
expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)
shutil.rmtree(tmpdirname)

# Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
def test_backport_bytecode_from_file_to_buffer(self):
maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
maximum_checked_in_model_version]["model_name"]

if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
# Backport model to v4
script_module_v4_buffer = _backport_for_mobile_to_buffer(
script_module_v5_path, maximum_checked_in_model_version - 1)
buf = io.StringIO()

# Check version of the model v4 from backport
bytesio = io.BytesIO(script_module_v4_buffer)
backport_version = _get_model_bytecode_version(bytesio)
assert(backport_version == maximum_checked_in_model_version - 1)

# Load model v4 from backport and run forward method
bytesio = io.BytesIO(script_module_v4_buffer)
mobile_module = _load_for_lite_interpreter(bytesio)
module_input = 1
mobile_module_result = mobile_module(module_input)
expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64)
torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)


if __name__ == '__main__':
run_tests()
2 changes: 2 additions & 0 deletions tools/build_variables.bzl
@@ -386,6 +386,8 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
"torch/csrc/autograd/FunctionsManual.cpp",
"torch/csrc/jit/api/module_save.cpp",
"torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp",
"torch/csrc/jit/mobile/backport.cpp",
"torch/csrc/jit/mobile/backport_manager.cpp",
"torch/csrc/jit/mobile/export_data.cpp",
# To be included for eager symbolication in lite interpreter
# when it is built in libtorch
