Skip to content

Commit

Permalink
Fix bug in flatbuffer deserialization
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#78344

Approved by: https://github.com/qihqi
  • Loading branch information
tugsbayasgalan authored and pytorchmergebot committed May 31, 2022
1 parent f42b42d commit 7e72c96
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/cpp/jit/test_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
Expand Down Expand Up @@ -837,6 +838,7 @@ TEST(FlatbufferTest, DuplicateSetState) {

auto buff = save_mobile_module_to_bytes(bc);
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());
ASSERT_TRUE(torch::jit::is_upgraders_map_populated());
const auto methods2 = bc.get_methods();
ASSERT_EQ(methods2.size(), expected_n);
}
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>

#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
Expand All @@ -14,6 +15,9 @@ Module parse_and_initialize_jit_module(
size_t size,
ExtraFilesMap& extra_files,
c10::optional<at::Device> device) {
#if ENABLE_UPGRADERS
populate_upgraders_graph_map();
#endif
auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
FlatbufferLoader loader;
mobile::Module mobilem = loader.parseModule(flatbuffer_module);
Expand Down

0 comments on commit 7e72c96

Please sign in to comment.