Skip to content

Commit

Permalink
[LT] Store OpKind for each IR subclass in a static field
Browse files Browse the repository at this point in the history
Summary: Currently OpKind is stored as an object field called op_ for each IR
node, and one usage of op_ is to avoid dynamic_cast in NodeCast when we
need to downcast a base-node pointer into a concrete sub-node pointer.
As a result, we need to construct and pass in an op when downcasting
nodes, and this becomes quite annoying when we start to implement the
trie-based IR node reusing. More importantly, the op for each subclass
should be unique for that subclass and thus making it a const static field
is a more logical design.

In this PR, we still keep the object-level op_ for easier XLA adoption. As
future work, we can come back to remove op_, make the op() method
virtual, and get rid of OpKind in all the node constructors.

Pull Request resolved: pytorch#76711

Approved by: https://github.com/wconstab, https://github.com/JackCaoG
  • Loading branch information
desertfire authored and pytorchmergebot committed May 6, 2022
1 parent 8b6a78f commit ac37ddc
Show file tree
Hide file tree
Showing 56 changed files with 212 additions and 63 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ libtorch_cpp_generated_sources = [
"torch/csrc/autograd/generated/Functions.cpp",
"torch/csrc/autograd/generated/variable_factories.h",
"torch/csrc/lazy/generated/LazyIr.h",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
Expand Down Expand Up @@ -1914,6 +1915,7 @@ test_suite(
for path in [
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
"aten/src/ATen/templates/LazyIr.cpp",
"aten/src/ATen/templates/LazyIr.h",
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
"aten/src/ATen/native/native_functions.yaml",
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/templates/LazyIr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// ${generated_comment}
${includes}

${namespace_prologue}

${opkind_definitions}

${namespace_epilogue}
2 changes: 2 additions & 0 deletions build.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def define_targets(rules):
srcs = [
":DispatchKeyNativeFunctions.cpp",
":DispatchKeyNativeFunctions.h",
":LazyIr.cpp",
":LazyIr.h",
":RegisterDispatchKey.cpp",
":native_functions.yaml",
Expand Down Expand Up @@ -111,6 +112,7 @@ _GENERATED_CPP = [
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
"torch/csrc/autograd/generated/python_variable_methods.cpp",
"torch/csrc/lazy/generated/LazyIr.cpp",
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
"torch/csrc/lazy/generated/RegisterLazy.cpp",
Expand Down
2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
)
if(BUILD_LAZY_TS_BACKEND)
list(APPEND GENERATED_CXX_TORCH
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterAutogradLazy.cpp"
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"
Expand Down Expand Up @@ -432,6 +433,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.cpp"
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
"${TOOLS_PATH}/autograd/templates/VariableType.h"
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"
Expand Down
8 changes: 6 additions & 2 deletions test/cpp/lazy/test_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace lazy {

class TestLeafNode : public Node {
public:
static const OpKind class_op_kind;

explicit TestLeafNode(size_t param)
: Node(OpKind(), /* num_outputs */ 1),
hash_(Hash(param)),
Expand All @@ -38,14 +40,16 @@ class TestLeafNode : public Node {
size_t param_;
};

const OpKind TestLeafNode::class_op_kind = OpKind();

TEST(IrTest, BasicTest) {
NodePtr node1 = MakeNode<TestLeafNode>(1);
NodePtr node2 = MakeNode<TestLeafNode>(2);
EXPECT_NE(node1->hash(), node2->hash());

EXPECT_EQ(node1->num_outputs(), 1);

const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get(), OpKind());
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}

Expand Down Expand Up @@ -102,7 +106,7 @@ TEST(IrTest, TsNodeTest) {

EXPECT_EQ(node1->num_outputs(), 1);

const TsNode* leafptr = NodeCast<TsNode>(node1.get(), OpKind(at::aten::view));
const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
EXPECT_TRUE(leafptr != nullptr);
}

Expand Down
30 changes: 17 additions & 13 deletions test/cpp/lazy/test_trie_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ namespace lazy {

class TrieCacheNode : public Node {
public:
static const OpKind class_op_kind;

explicit TrieCacheNode(size_t id)
: Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
: Node(class_op_kind, /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
~TrieCacheNode() override = default;

bool Equal(size_t id) const {
Expand All @@ -36,6 +38,8 @@ class TrieCacheNode : public Node {
hash_t hash_;
};

const OpKind TrieCacheNode::class_op_kind = OpKind();

TEST(TrieCacheTest, TestSinglePath) {
FLAGS_torch_lazy_reuse_ir = true;
TrieCache::Get()->Clear();
Expand All @@ -45,9 +49,9 @@ TEST(TrieCacheTest, TestSinglePath) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep

EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}

Expand All @@ -67,20 +71,20 @@ TEST(TrieCacheTest, TestTwoPaths) {
NodePtr c = MakeNode<TrieCacheNode>(2);
TrieCache::Get()->ResetCurrent(); // MarkStep

EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
EXPECT_NE(d.get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep

EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
TrieCache::Get()->ResetCurrent(); // MarkStep

EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
TrieCache::Get()->ResetCurrent(); // MarkStep
}

Expand Down
2 changes: 2 additions & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

# This is duplicated in caffe2/CMakeLists.txt for now and not yet used in buck
GENERATED_LAZY_TS_CPP = [
"lazy/generated/LazyIr.cpp",
"lazy/generated/LazyNativeFunctions.cpp",
"lazy/generated/RegisterAutogradLazy.cpp",
"lazy/generated/RegisterLazy.cpp",
Expand Down Expand Up @@ -425,6 +426,7 @@ lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
"torch/csrc/lazy/ts_backend/ops/to_copy.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/lazy/core/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
return stream;
}

// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
// clean up once the migration is done.
template <typename T>
const T* NodeCast(const Node* node, OpKind op) {
if (op != node->op()) {
Expand All @@ -187,6 +189,18 @@ const T* NodeCast(const Node* node, OpKind op) {
#endif
}

// Downcast a base Node pointer to concrete type T without requiring the
// caller to pass an OpKind: T's static class_op_kind is compared against
// the node's op to reject mismatches cheaply.
template <typename T>
const T* NodeCast(const Node* node) {
  if (node->op() != T::class_op_kind) {
    return nullptr;
  }
#ifdef NDEBUG
  // Release builds trust the OpKind check and skip RTTI entirely.
  return static_cast<const T*>(node);
#else
  // Debug builds verify the downcast via RTTI; a mismatch here (i.e. an
  // OpKind/type inconsistency) throws std::bad_cast instead of silently
  // returning a bogus pointer.
  return &dynamic_cast<const T&>(*node);
#endif
}


// Represents a specific output produced by a node. Since the output of a node
// can be composed by multiple outputs, the node+index coordinates fully qualify
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/lazy/core/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace torch {
namespace lazy {

template <typename T, typename... Args>
NodePtr ReuseNode(OpKind op, Args&&... args) {
NodePtr ReuseNode(Args&&... args) {
if (FLAGS_torch_lazy_reuse_ir) {
return LookupNodeFromTrieCache<T>(op, std::forward<Args>(args)...);
return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
}
return nullptr;
}
Expand All @@ -27,16 +27,16 @@ template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
NodePtr node = std::make_shared<T>(std::forward<Args>(args)...);
if (FLAGS_torch_lazy_reuse_ir) {
// If ir caching is enabled, we need to record all new nodes
// If ir caching is enabled, we need to record all new nodes
TrieCache::Get()->Insert(node);
}
return node;
}

// op is passed in for a more efficient node casting, see the implementation of NodeCast
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(OpKind op, Args&&... args) {
NodePtr node = ReuseNode<T>(op, std::forward<Args>(args)...);
NodePtr ReuseOrMakeNode(Args&&... args) {
NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
if (!node) {
node = MakeNode<T>(std::forward<Args>(args)...);
}
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/lazy/core/trie.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class TORCH_API TrieCache {
};

template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(OpKind op, Args&&... args) {
NodePtr LookupNodeFromTrieCache(Args&&... args) {
auto& successors = TrieCache::Get()->Current()->successors;
for (auto it = successors.begin(); it != successors.end(); it++) {
NodePtr ir_node = (*it)->ir_node;
const T* concrete_node = NodeCast<T>(ir_node.get(), op);
const T* concrete_node = NodeCast<T>(ir_node.get());
if (concrete_node && concrete_node->Equal(std::forward<Args>(args)...)) {
TORCH_LAZY_COUNTER("IrNodeReused::" + std::string(typeid(T).name()), 1);
TrieCache::Get()->SetCurrent(it);
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
namespace torch {
namespace lazy {

// Per-class static OpKind used by the op-less NodeCast overload to avoid
// dynamic_cast when downcasting; must stay in sync with the OpKind each
// constructor passes to the base node.
const OpKind TSNativeBatchNormBackward::class_op_kind(at::aten::native_batch_norm_backward);
const OpKind TSNativeBatchNormForward::class_op_kind(at::aten::native_batch_norm);

TSNativeBatchNormBackward::TSNativeBatchNormBackward(
const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace lazy {
// Node for the backward batch norm operator.
class TSNativeBatchNormBackward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean,
Expand Down Expand Up @@ -35,6 +37,8 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {

class TSNativeBatchNormForward : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight,
const torch::lazy::Value& bias, const torch::lazy::Value& running_mean,
const torch::lazy::Value& running_var, bool training,
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
}

} // namespace

const OpKind Cast::class_op_kind(ltc_cast);

Cast::Cast(
const Value& input,
at::ScalarType dtype,
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace lazy {

class TORCH_API Cast : public TsNode {
public:
static const OpKind class_op_kind;

Cast(
const Value& input,
at::ScalarType dtype,
Expand Down
4 changes: 3 additions & 1 deletion torch/csrc/lazy/ts_backend/ops/device_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
namespace torch {
namespace lazy {

const OpKind DeviceData::class_op_kind(ltc_device_data);

DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TsNode(
ltc_device_data,
Expand All @@ -22,7 +24,7 @@ std::string DeviceData::ToString() const {
}

const DeviceData* DeviceData::Cast(const Node* node) {
return NodeCast<DeviceData>(node, ltc_device_data);
return NodeCast<DeviceData>(node);
}

} // namespace lazy
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/device_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace lazy {

class TORCH_API DeviceData : public TsNode {
public:
static const OpKind class_op_kind;

explicit DeviceData(std::shared_ptr<BackendData> data);

std::string ToString() const override;
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
namespace torch {
namespace lazy {

const OpKind Expand::class_op_kind(at::aten::expand);

Expand::Expand(
const Value& input,
std::vector<int64_t> size,
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/expand.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace lazy {

class TORCH_API Expand : public TsNode {
public:
static const OpKind class_op_kind;

Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);

std::string ToString() const override;
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/random_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
namespace torch {
namespace lazy {

const OpKind Normal::class_op_kind(c10::Symbol::fromQualString("aten::normal_"));

Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(torch::lazy::OpKind(c10::Symbol::fromQualString("aten::normal_")),
{self}, std::move(shapes),
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/random_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ namespace lazy {

class Normal : public torch::lazy::TsNode {
public:
static const OpKind class_op_kind;

Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes);

std::string ToString() const override;
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace lazy {

using at::operator<<;

const OpKind Scalar::class_op_kind(at::prim::Constant);

Scalar::Scalar(const at::Scalar& value, Shape shape)
: TsNode(
OpKind(at::prim::Constant),
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace lazy {
// computation graph.
class TORCH_API Scalar : public TsNode {
public:
static const OpKind class_op_kind;

Scalar(const at::Scalar& value, Shape shape);
Scalar(const at::Scalar& value, c10::ScalarType type);

Expand Down
9 changes: 9 additions & 0 deletions torch/csrc/lazy/ts_backend/ops/to_copy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>

namespace torch {
namespace lazy {

// Unique OpKind for all ToCopy nodes; consulted by the op-less NodeCast
// overload so downcasting does not need dynamic_cast or a caller-supplied op.
const OpKind ToCopy::class_op_kind(at::aten::_to_copy);

} // namespace lazy
} // namespace torch
Loading

0 comments on commit ac37ddc

Please sign in to comment.