Commit
Autogenerate ATen dispatch for JIT nodes
apaszke authored and soumith committed Oct 26, 2017
1 parent 869bdeb commit 61afb0d
Showing 7 changed files with 238 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@ torch/lib/build
 torch/lib/tmp_install
 torch/lib/include
 torch/lib/torch_shm_manager
+torch/csrc/jit/generated/*
 torch/csrc/autograd/generated/*
 torch/csrc/cudnn/cuDNN.cpp
 torch/csrc/nn/THNN.cwrap
11 changes: 9 additions & 2 deletions setup.py
@@ -250,6 +250,7 @@ def run(self):
         from tools.cwrap.plugins.Broadcast import Broadcast
         from tools.cwrap.plugins.ProcessorSpecificPlugin import ProcessorSpecificPlugin
         from tools.autograd.gen_variable_type import gen_variable_type
+        from tools.jit.gen_jit_dispatch import gen_jit_dispatch
         thp_plugin = THPPlugin()
         cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
             ProcessorSpecificPlugin(), BoolOption(), thp_plugin,
@@ -261,11 +262,16 @@ def run(self):
         ])
         # Build ATen based Variable classes
         autograd_gen_dir = 'torch/csrc/autograd/generated'
-        if not os.path.exists(autograd_gen_dir):
-            os.mkdir(autograd_gen_dir)
+        jit_gen_dir = 'torch/csrc/jit/generated'
+        for d in (autograd_gen_dir, jit_gen_dir):
+            if not os.path.exists(d):
+                os.mkdir(d)
         gen_variable_type(
             'torch/lib/build/ATen/ATen/Declarations.yaml',
             autograd_gen_dir)
+        gen_jit_dispatch(
+            'torch/lib/build/ATen/ATen/Declarations.yaml',
+            jit_gen_dir)

         # It's an old-style class in Python 2.7...
         setuptools.command.build_ext.build_ext.run(self)
@@ -403,6 +409,7 @@ def check_file(f):
     "torch/csrc/jit/passes/onnx.cpp",
     "torch/csrc/jit/passes/dead_code_elimination.cpp",
     "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
+    "torch/csrc/jit/generated/aten_dispatch.cpp",
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/engine.cpp",
     "torch/csrc/autograd/function.cpp",
15 changes: 7 additions & 8 deletions tools/autograd/gen_variable_type.py
@@ -49,8 +49,8 @@
 auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""")

 FUNCTION_DECLARATION = CodeTemplate("""\
-struct ${op} : public Function {
-  using Function::Function;
+struct ${op} : public TraceableFunction {
+  using TraceableFunction::TraceableFunction;
   variable_list apply(const variable_list& grads) override;
   std::string name() override { return "${op}"; }
   void releaseVariables() override {
@@ -116,7 +116,7 @@

 RECORD_TRACE = CodeTemplate("""\
 if (jit::tracer::isTracing({ ${tensor_args} })) {
-  jit::Node *n = jit::tracer::recordTrace( "${api_name}", ${trace_inputs}, ${trace_outputs} );
+  jit::Node *n = jit::tracer::recordTrace( "${trace_name}", ${trace_inputs}, ${trace_outputs} );
   ${record_attributes}
 }
 """)
@@ -642,6 +642,10 @@ def emit_record_trace(env, declaration):
     if not local['record_attributes']:
         local['record_attributes'].append('(void)n;')

+    local['trace_name'] = declaration['api_name']
+    if local['trace_name'].endswith('_'):
+        local['trace_name'] = local['trace_name'][:-1]
+
     combined = nested_dict(local, nested_dict(env, declaration))
     return RECORD_TRACE.substitute(combined)
@@ -778,12 +782,10 @@ def load_aten_declarations(path):

     # enrich declarations with additional information
     for declaration in declarations:
-        args = []
         for arg in declaration['arguments']:
             simple_type = arg['type']
             simple_type = simple_type.replace(' &', '').replace('const ', '')
             simple_type = simple_type.replace('Generator *', 'Generator')
-            args.append(simple_type)
+            arg['simple_type'] = simple_type
         declaration['formals'] = [arg['type'] + ' ' + arg['name']
                                   for arg in declaration['arguments']]
@@ -864,9 +866,6 @@ def get_signature(name, params, call_args):
 def gen_variable_type(declarations, out):
     aten_decls = load_aten_declarations(declarations)

-    def by_name(option):
-        return option['name']
-
     def group_declarations_by_signature():
         d = defaultdict(list)
         for declaration in aten_decls:
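A note on the trace_name addition above: stripping the trailing underscore means an in-place op is recorded in the trace under its out-of-place name. As a minimal sketch, RECORD_TRACE might expand roughly like this for a hypothetical in-place op with api_name "add_" (the tensor names and list syntax here are placeholders, not actual generator output):

    if (jit::tracer::isTracing({ self, other })) {
      // trace_name is "add": the trailing '_' of api_name "add_" was stripped
      jit::Node *n = jit::tracer::recordTrace( "add", { self, other }, { result } );
      (void)n;  // no scalar attributes to record in this example
    }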
Empty file added tools/jit/__init__.py
124 changes: 124 additions & 0 deletions tools/jit/gen_jit_dispatch.py
@@ -0,0 +1,124 @@
import os
import argparse
from collections import defaultdict
from tools.shared.module_loader import import_module
from itertools import count
from ..autograd.gen_variable_type import load_aten_declarations, CodeTemplate, write, \
    FALLTHROUGH_RETURN_TYPES, FALLTHROUGH_FUNCTIONS, GENERATED_COMMENT

template_path = os.path.join(os.path.dirname(__file__), 'templates')

ATEN_DISPATCH_H = CodeTemplate.from_file(template_path + '/aten_dispatch.h')
ATEN_DISPATCH_CPP = CodeTemplate.from_file(template_path + '/aten_dispatch.cpp')

ATTR_METHOD_MAP = {
    'int64_t': 'i',
    'IntList': 'is',
    'Scalar': 't',
    'bool': 'i',
    'double': 'f',
    'std::array<bool, 2>': 'is',
    'std::array<bool, 3>': 'is',
}

TYPE_CASTS = {
    'std::array<bool, 2>': 'as_bool_array<2>',
    'std::array<bool, 3>': 'as_bool_array<3>',
    'Scalar': 'Scalar',
    'IntList': 'std::vector<int64_t>',
}

ATTR_ASSIGNMENT = CodeTemplate("""\
auto ${name} = ${type_cast}(node->${method}(stringToSymbol("${name}")));\
""")

CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
CALL_METHOD = CodeTemplate("vars[0].${name}(${args})")

CONSTRUCTOR = CodeTemplate("""\
{"${descriptor}", [](Node *node) {
  ${assignments}
  return TensorOp([=](const variable_list& vars) -> variable_list {
    return pack_list(${call});
  }, "${name}", ${num_inputs});
}},
""")


def is_jit_op(decl):
    return (not decl['api_name'].endswith('_') and
            not decl['name'].endswith('_out') and
            not decl['name'].endswith('_forward') and
            not any(arg['simple_type'] == 'Generator' for arg in decl['arguments']) and
            not any(arg['simple_type'] == 'SparseTensor' for arg in decl['arguments']) and
            not decl['return_type'] in FALLTHROUGH_RETURN_TYPES and
            not decl['name'] in FALLTHROUGH_FUNCTIONS)


def gen_jit_dispatch(declarations, out):
    aten_decls = load_aten_declarations(declarations)
    jit_decls = [d for d in aten_decls if is_jit_op(d)]

    def is_tensor_arg(arg):
        return arg['simple_type'] in {'Tensor', 'TensorList'}

    ops = {}
    for decl in jit_decls:
        arguments = decl['arguments']
        name = decl['name']
        scalar_args = [arg for arg in arguments if not is_tensor_arg(arg)]

        # Descriptor is a unique identifier for a particular overload of an op
        attr_names = sorted([arg['name'] for arg in scalar_args])
        num_inputs = len(arguments) - len(scalar_args)
        descriptor = '-'.join([decl['name'], str(num_inputs)] + attr_names)

        # All scalar args need to be assigned, so they can be captured by a lambda
        assignments = [ATTR_ASSIGNMENT.substitute(type=arg['simple_type'],
                                                  type_cast=TYPE_CASTS.get(arg['simple_type'], arg['simple_type']),
                                                  name=arg['name'],
                                                  method=ATTR_METHOD_MAP[arg['simple_type']])
                       for arg in scalar_args]

        # Generate the actual ATen call. This gets a bit tricky because of
        # TensorList arguments, and functions that are only available as methods.
        if 'namespace' in decl['method_of']:
            if any(arg['simple_type'] == 'TensorList' for arg in arguments):
                assert sum(map(is_tensor_arg, arguments)) == 1
                args = ['as_tensor_list(vars)' if is_tensor_arg(arg) else arg['name']
                        for arg in arguments]
            else:
                tensor_id = iter(count(start=0))
                args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                        for arg in arguments]
            call = CALL_NAMESPACE.substitute(name=name, args=args)
        else:
            tensor_id = iter(count(start=1))
            args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                    for arg in arguments[1:]]
            call = CALL_METHOD.substitute(name=name, args=args)

        constructor = CONSTRUCTOR.substitute(descriptor=descriptor, name=name, call=call,
                                             assignments=assignments, num_inputs=num_inputs)
        assert descriptor not in ops, descriptor
        ops[descriptor] = constructor

    # Sort the generated snippets to ensure that the generation is deterministic
    env = {'constructors': sorted(list(ops.values()))}
    write(out, 'aten_dispatch.h', ATEN_DISPATCH_H, env)
    write(out, 'aten_dispatch.cpp', ATEN_DISPATCH_CPP, env)


def main():
    parser = argparse.ArgumentParser(
        description='Generate JIT op dispatch')
    parser.add_argument('declarations', metavar='DECL',
                        help='path to Declarations.yaml')
    parser.add_argument('out', metavar='OUT',
                        help='path to output directory')
    args = parser.parse_args()
    gen_jit_dispatch(args.declarations, args.out)


if __name__ == '__main__':
    main()
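To make the CONSTRUCTOR template concrete, here is an illustrative expansion for a hypothetical overload clamp(Tensor self, Scalar min, Scalar max): one tensor input, two scalar attributes (sorted alphabetically as max, min in the descriptor), each read off the node with the Scalar accessor 't' and captured by value in the lambda. This is a sketch of the expected shape of a generated entry, not a line copied from the generated file:

    {"clamp-1-max-min", [](Node *node) {
      // scalar attributes become lambda captures
      auto max = Scalar(node->t(stringToSymbol("max")));
      auto min = Scalar(node->t(stringToSymbol("min")));
      return TensorOp([=](const variable_list& vars) -> variable_list {
        return pack_list(at::clamp(vars[0], min, max));
      }, "clamp", 1);
    }},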
72 changes: 72 additions & 0 deletions tools/jit/templates/aten_dispatch.cpp
@@ -0,0 +1,72 @@
#include "aten_dispatch.h"
#include "torch/csrc/jit/interned_strings.h"
#include "torch/csrc/utils/functional.h"

#include <unordered_map>
#include <cstring>

// ${generated_comment}

namespace torch { namespace jit {

using autograd::Variable;
using autograd::variable_list;
using at::Scalar;
using at::Tensor;
using at::IntList;
using at::TensorList;
using operator_constructor = std::function<TensorOp(jit::Node*)>;

namespace {

variable_list pack_list(Tensor v) { return { std::move(v) }; }
variable_list pack_list(Scalar v) { return { v.toTensor() }; }
variable_list pack_list(std::vector<Tensor> t) { return fmap<Variable>(t); }
variable_list pack_list(std::tuple<Tensor, Tensor> v) {
return { std::move(std::get<0>(v)), std::move(std::get<1>(v)) };
}
variable_list pack_list(std::tuple<Tensor, Tensor, Tensor> v) {
return { std::get<0>(v), std::get<1>(v), std::get<2>(v) };
}

std::vector<Tensor> as_tensor_list(const variable_list& vars) {
return fmap(vars, [](Variable v) { return static_cast<Tensor>(v); });
}

template<size_t N>
std::array<bool, N> as_bool_array(const std::vector<int64_t>& vec) {
std::array<bool, N> res;
JIT_ASSERT(vec.size() == N);
std::copy(vec.begin(), vec.end(), res.begin());
return res;
}

std::unordered_map<std::string, operator_constructor> constructors = {
${constructors}
};

std::string getDescriptor(jit::Node* n) {
std::stringstream s;
s << symbolToString(n->kind()) << "-" << n->inputs().size();
std::vector<const char*> attr_names = fmap(n->attributeNames(), &symbolToString);
std::sort(attr_names.begin(), attr_names.end(), [](const char *a, const char *b) {
return std::strcmp(a, b) < 0;
});
for (const auto & name : attr_names)
s << "-" << name;
return s.str();
}

} // anonymous namespace

TensorOp getTensorOp(jit::Node* n) {
auto signature = getDescriptor(n);
try {
return constructors.at(signature)(n);
} catch (std::out_of_range &e) {
throw std::runtime_error("Unsupported op descriptor: " + signature + ". "
"File a bug report.");
}
};

}} // namespace torch::jit
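Runtime dispatch is thus two steps: build the descriptor string from the node, then look up the matching constructor to obtain a ready-to-call TensorOp. A minimal usage sketch (runNode is a hypothetical helper, not part of this commit; it assumes node describes a supported op and inputs holds its tensor inputs in node order):

    autograd::variable_list runNode(jit::Node* node, const autograd::variable_list& inputs) {
      TensorOp op = getTensorOp(node);  // throws std::runtime_error for unknown descriptors
      return op.op(inputs);             // invokes the captured ATen call
    }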
25 changes: 25 additions & 0 deletions tools/jit/templates/aten_dispatch.h
@@ -0,0 +1,25 @@
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/autograd/function.h"

#include <functional>

// ${generated_comment}

namespace torch { namespace jit {

struct TensorOp {
using op_type = std::function<autograd::variable_list(const autograd::variable_list&)>;

TensorOp(op_type op, std::string name, size_t num_inputs)
: op(op)
, name(name)
, num_inputs(num_inputs) {}

const op_type op;
const std::string name;
const size_t num_inputs;
};

TensorOp getTensorOp(jit::Node* n);

}} // namespace torch::jit;
