forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Autogenerate ATen dispatch for JIT nodes
- Loading branch information
Showing
7 changed files
with
238 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import os | ||
import argparse | ||
from collections import defaultdict | ||
from tools.shared.module_loader import import_module | ||
from itertools import count | ||
from ..autograd.gen_variable_type import load_aten_declarations, CodeTemplate, write, \ | ||
FALLTHROUGH_RETURN_TYPES, FALLTHROUGH_FUNCTIONS, GENERATED_COMMENT | ||
|
||
# Directory containing the C++ code templates this generator fills in.
template_path = os.path.join(os.path.dirname(__file__), 'templates')

ATEN_DISPATCH_H = CodeTemplate.from_file(template_path + '/aten_dispatch.h')
ATEN_DISPATCH_CPP = CodeTemplate.from_file(template_path + '/aten_dispatch.cpp')

# Maps a scalar argument's C++ type to the jit::Node attribute-accessor
# suffix used to read it in the generated code (node->i(...), node->is(...),
# node->f(...), ...).
ATTR_METHOD_MAP = {
    'int64_t': 'i',
    'IntList': 'is',
    'Scalar': 't',
    'bool': 'i',
    'double': 'f',
    'std::array<bool, 2>': 'is',
    'std::array<bool, 3>': 'is',
}

# Casts applied to attribute values whose stored representation differs from
# the C++ type the ATen call expects.  Types not listed here are cast to
# themselves (see TYPE_CASTS.get(..., simple_type) at the use site).
TYPE_CASTS = {
    'std::array<bool, 2>': 'as_bool_array<2>',
    'std::array<bool, 3>': 'as_bool_array<3>',
    'Scalar': 'Scalar',
    'IntList': 'std::vector<int64_t>',
}

# Pulls one scalar attribute out of the node into a local variable so the
# dispatch lambda can capture it by value.
ATTR_ASSIGNMENT = CodeTemplate("""\
auto ${name} = ${type_cast}(node->${method}(stringToSymbol("${name}")));\
""")

# Two ways of emitting the ATen call: a free function in the at:: namespace,
# or a Tensor method invoked on the first runtime input.
CALL_NAMESPACE = CodeTemplate("at::${name}(${args})")
CALL_METHOD = CodeTemplate("vars[0].${name}(${args})")

# One entry in the descriptor -> constructor table emitted into
# aten_dispatch.cpp (see the `constructors` map in that template).
CONSTRUCTOR = CodeTemplate("""\
{"${descriptor}", [](Node *node) {
  ${assignments}
  return TensorOp([=](const variable_list& vars) -> variable_list {
    return pack_list(${call});
  }, "${name}", ${num_inputs});
}},
""")
|
||
|
||
def is_jit_op(decl):
    """Return True if this ATen declaration should get a JIT dispatch entry.

    Excluded are: in-place ops (api_name ending in '_'), '_out' and
    '_forward' variants, ops taking a Generator or SparseTensor argument,
    and ops already covered by the fallthrough return-type / function lists
    imported from gen_variable_type.
    """
    return (not decl['api_name'].endswith('_') and
            not decl['name'].endswith('_out') and
            not decl['name'].endswith('_forward') and
            not any(arg['simple_type'] == 'Generator' for arg in decl['arguments']) and
            not any(arg['simple_type'] == 'SparseTensor' for arg in decl['arguments']) and
            decl['return_type'] not in FALLTHROUGH_RETURN_TYPES and
            decl['name'] not in FALLTHROUGH_FUNCTIONS)
|
||
|
||
def gen_jit_dispatch(declarations, out):
    """Generate aten_dispatch.h / aten_dispatch.cpp from ATen declarations.

    declarations -- path to Declarations.yaml
    out -- output directory the generated files are written into
    """
    aten_decls = load_aten_declarations(declarations)
    jit_decls = [d for d in aten_decls if is_jit_op(d)]

    def is_tensor_arg(arg):
        return arg['simple_type'] in {'Tensor', 'TensorList'}

    ops = {}
    for decl in jit_decls:
        arguments = decl['arguments']
        name = decl['name']
        scalar_args = [arg for arg in arguments if not is_tensor_arg(arg)]

        # Descriptor is a unique identifier for a particular overload of an op:
        # name, number of runtime tensor inputs, and sorted attribute names.
        attr_names = sorted(arg['name'] for arg in scalar_args)
        num_inputs = len(arguments) - len(scalar_args)
        descriptor = '-'.join([decl['name'], str(num_inputs)] + attr_names)

        # All scalar args need to be assigned, so they can be captured by a lambda
        assignments = [ATTR_ASSIGNMENT.substitute(type=arg['simple_type'],
                                                  type_cast=TYPE_CASTS.get(arg['simple_type'], arg['simple_type']),
                                                  name=arg['name'],
                                                  method=ATTR_METHOD_MAP[arg['simple_type']])
                       for arg in scalar_args]

        # Generate the actual ATen call. This gets a bit tricky because of
        # TensorList arguments, and functions that are only available as methods.
        if 'namespace' in decl['method_of']:
            if any(arg['simple_type'] == 'TensorList' for arg in arguments):
                # A single TensorList consumes every runtime tensor input.
                assert sum(map(is_tensor_arg, arguments)) == 1
                args = ['as_tensor_list(vars)' if is_tensor_arg(arg) else arg['name']
                        for arg in arguments]
            else:
                tensor_id = count(start=0)  # count() is already an iterator
                args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                        for arg in arguments]
            call = CALL_NAMESPACE.substitute(name=name, args=args)
        else:
            # Method call: vars[0] is the receiver, so tensor indices start at 1
            # and the first declared argument is skipped.
            tensor_id = count(start=1)
            args = ['vars[{}]'.format(next(tensor_id)) if is_tensor_arg(arg) else arg['name']
                    for arg in arguments[1:]]
            call = CALL_METHOD.substitute(name=name, args=args)

        constructor = CONSTRUCTOR.substitute(descriptor=descriptor, name=name, call=call,
                                             assignments=assignments, num_inputs=num_inputs)
        assert descriptor not in ops, descriptor
        ops[descriptor] = constructor

    # Sort the generated snippets to ensure that the generation is deterministic
    env = {'constructors': sorted(ops.values())}
    write(out, 'aten_dispatch.h', ATEN_DISPATCH_H, env)
    write(out, 'aten_dispatch.cpp', ATEN_DISPATCH_CPP, env)
|
||
|
||
def main():
    """Command-line entry point: parse arguments and run the generator."""
    parser = argparse.ArgumentParser(
        description='Generate JIT op dispatch')
    parser.add_argument('declarations', metavar='DECL',
                        help='path to Declarations.yaml')
    parser.add_argument('out', metavar='OUT',
                        help='path to output directory')
    args = parser.parse_args()
    gen_jit_dispatch(args.declarations, args.out)
|
||
|
||
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#include "aten_dispatch.h" | ||
#include "torch/csrc/jit/interned_strings.h" | ||
#include "torch/csrc/utils/functional.h" | ||
|
||
#include <unordered_map> | ||
#include <cstring> | ||
|
||
// ${generated_comment} | ||
|
||
namespace torch { namespace jit { | ||
|
||
using autograd::Variable; | ||
using autograd::variable_list; | ||
using at::Scalar; | ||
using at::Tensor; | ||
using at::IntList; | ||
using at::TensorList; | ||
using operator_constructor = std::function<TensorOp(jit::Node*)>; | ||
|
||
namespace { | ||
|
||
// Normalize the various ATen return shapes (Tensor, Scalar, vector, tuples)
// into a single variable_list.
variable_list pack_list(Tensor v) { return { std::move(v) }; }
variable_list pack_list(Scalar v) { return { v.toTensor() }; }
variable_list pack_list(std::vector<Tensor> t) { return fmap<Variable>(t); }
variable_list pack_list(std::tuple<Tensor, Tensor> v) {
  return { std::move(std::get<0>(v)), std::move(std::get<1>(v)) };
}
variable_list pack_list(std::tuple<Tensor, Tensor, Tensor> v) {
  // Move, for consistency with the two-element overload above.
  return { std::move(std::get<0>(v)),
           std::move(std::get<1>(v)),
           std::move(std::get<2>(v)) };
}
|
||
std::vector<Tensor> as_tensor_list(const variable_list& vars) { | ||
return fmap(vars, [](Variable v) { return static_cast<Tensor>(v); }); | ||
} | ||
|
||
template<size_t N> | ||
std::array<bool, N> as_bool_array(const std::vector<int64_t>& vec) { | ||
std::array<bool, N> res; | ||
JIT_ASSERT(vec.size() == N); | ||
std::copy(vec.begin(), vec.end(), res.begin()); | ||
return res; | ||
} | ||
|
||
// Maps descriptor strings (see getDescriptor below) to factories that bind a
// Node's attributes into a ready-to-run TensorOp.  The body is filled in by
// gen_jit_dispatch.py.
std::unordered_map<std::string, operator_constructor> constructors = {
  ${constructors}
};
|
||
// Build the lookup key for a node: kind, number of inputs, and sorted
// attribute names, joined with '-'.  Must stay in sync with the descriptors
// emitted by gen_jit_dispatch.py.
std::string getDescriptor(jit::Node* n) {
  std::stringstream s;
  s << symbolToString(n->kind()) << "-" << n->inputs().size();
  std::vector<const char*> attr_names = fmap(n->attributeNames(), &symbolToString);
  // Sort by string contents, not pointer value.
  std::sort(attr_names.begin(), attr_names.end(), [](const char *a, const char *b) {
    return std::strcmp(a, b) < 0;
  });
  for (const auto & name : attr_names)
    s << "-" << name;
  return s.str();
}
|
||
} // anonymous namespace | ||
|
||
// Look up and instantiate the TensorOp for a JIT node.
// Throws std::runtime_error if no generated entry matches the node's
// descriptor.  Uses find() rather than at() + catch so that an
// out_of_range thrown *inside* a constructor is not misreported as a
// missing descriptor.
TensorOp getTensorOp(jit::Node* n) {
  auto signature = getDescriptor(n);
  auto it = constructors.find(signature);
  if (it == constructors.end()) {
    throw std::runtime_error("Unsupported op descriptor: " + signature + ". "
                             "File a bug report.");
  }
  return it->second(n);
}
|
||
}} // namespace torch::jit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#include "torch/csrc/jit/ir.h" | ||
#include "torch/csrc/autograd/function.h" | ||
|
||
#include <functional> | ||
|
||
// ${generated_comment} | ||
|
||
namespace torch { namespace jit { | ||
|
||
struct TensorOp { | ||
using op_type = std::function<autograd::variable_list(const autograd::variable_list&)>; | ||
|
||
TensorOp(op_type op, std::string name, size_t num_inputs) | ||
: op(op) | ||
, name(name) | ||
, num_inputs(num_inputs) {} | ||
|
||
const op_type op; | ||
const std::string name; | ||
const size_t num_inputs; | ||
}; | ||
|
||
TensorOp getTensorOp(jit::Node* n); | ||
|
||
}} // namespace torch::jit; |