Toffee graph exporting for PyTorch.
This commit adds a new exporter pass which takes a graph and returns a
string containing the human-readable protobuf representation of the model.
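
The pass is driven from Python via torch._C._jit_pass_export on a trace.
As a rough sketch (mirroring the new test_export test added in this
commit):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    # Record a trace of a small expression, then ask the exporter pass
    # for the human-readable protobuf string of the traced graph.
    trace, (x, y) = torch._C._tracer_enter((x, y))
    z = -torch.sigmoid(torch.tanh(x * (x + y)))
    torch._C._tracer_exit((z,))
    torch._C._jit_pass_lint(trace)
    print(torch._C._jit_pass_export(trace))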

We have two strategies for how conversions are implemented:

- If a Python autograd function has a primspec static method, we invoke
  it to get the Toffee conversion (see the sketch after this list).  The
  method should use torch.toffee.op to construct its return value; the
  particular data representation is opaque and subject to change in the
  future.

- Otherwise, the exporter falls back to a giant if statement, which
  converts nodes manually using the JIT IR C++ API and the Toffee IR C++
  protobuf API.
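
For the first strategy, a primspec looks roughly like the one this commit
adds for Add in torch/autograd/_functions/basic_ops.py (the absolute
import below is spelled out for illustration; the real file uses a
relative import):

    import torch.toffee
    from torch.autograd.function import InplaceFunction

    class Add(InplaceFunction):

        @staticmethod
        def primspec(a, b, inplace=False):
            # No Toffee conversion is provided for the in-place variant.
            if inplace:
                return None
            # torch.toffee.op builds the conversion in the expected
            # (opaque) format.
            return torch.toffee.op("Add", a, b)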

You must check out a copy of the ToffeeIR repo
https://github.com/ProjectToffee/ToffeeIR at torch/lib; at the moment
we don't have a subtree/submodule set up.

Technical debt in this commit:

- To get protobuf headers in scope, we unconditionally add $CONDA_PREFIX/include
  to the include path.  This needs to be replaced with a more robust mechanism.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang authored and soumith committed Sep 5, 2017
1 parent 890c207 commit dd58b14
Showing 12 changed files with 451 additions and 71 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
@@ -29,6 +29,8 @@ addons:
      - ubuntu-toolchain-r-test
    packages:
      - g++-5
      - protobuf-compiler
      - libprotobuf-dev

script:
- OMP_NUM_THREADS=2 ./test/run_test.sh
10 changes: 9 additions & 1 deletion setup.py
@@ -257,6 +257,10 @@ def run(self):
################################################################################

include_dirs = []

if os.getenv('CONDA_PREFIX'):
    include_dirs.append(os.path.join(os.getenv('CONDA_PREFIX'), "include"))

library_dirs = []
extra_link_args = []
extra_compile_args = ['-std=c++11', '-Wno-write-strings',
@@ -293,6 +297,7 @@ def run(self):
    tmp_install_path + "/include/THPP",
    tmp_install_path + "/include/THNN",
    tmp_install_path + "/include/ATen",
    tmp_install_path + "/include/toffee",
]

library_dirs.append(lib_path)
@@ -308,6 +313,7 @@ def run(self):
ATEN_LIB = os.path.join(lib_path, 'libATen.so.1')
THD_LIB = os.path.join(lib_path, 'libTHD.so.1')
NCCL_LIB = os.path.join(lib_path, 'libnccl.so.1')
TOFFEE_LIB = os.path.join(lib_path, 'libtoffee.so.1')
if platform.system() == 'Darwin':
    TH_LIB = os.path.join(lib_path, 'libTH.1.dylib')
    THS_LIB = os.path.join(lib_path, 'libTHS.1.dylib')
@@ -319,14 +325,15 @@ def run(self):
    ATEN_LIB = os.path.join(lib_path, 'libATen.1.dylib')
    THD_LIB = os.path.join(lib_path, 'libTHD.1.dylib')
    NCCL_LIB = os.path.join(lib_path, 'libnccl.1.dylib')
    TOFFEE_LIB = os.path.join(lib_path, 'libtoffee.1.dylib')

if WITH_NCCL and (subprocess.call('ldconfig -p | grep libnccl >/dev/null', shell=True) == 0 or
                  subprocess.call('/sbin/ldconfig -p | grep libnccl >/dev/null', shell=True) == 0):
    SYSTEM_NCCL = True

main_compile_args = ['-D_THP_CORE']
main_libraries = ['shm']
main_link_args = [TH_LIB, THS_LIB, THPP_LIB, THNN_LIB, ATEN_LIB]
main_link_args = [TH_LIB, THS_LIB, THPP_LIB, THNN_LIB, ATEN_LIB, TOFFEE_LIB]
main_sources = [
    "torch/csrc/PtrWrapper.cpp",
    "torch/csrc/Module.cpp",
@@ -346,6 +353,7 @@ def run(self):
    "torch/csrc/jit/init.cpp",
    "torch/csrc/jit/ir.cpp",
    "torch/csrc/jit/graph_fuser.cpp",
    "torch/csrc/jit/graph_exporter.cpp",
    "torch/csrc/jit/init_pass.cpp",
    "torch/csrc/jit/dead_code_elimination.cpp",
    "torch/csrc/jit/test_jit.cpp",
27 changes: 0 additions & 27 deletions test/expect/TestJit.test_alexnet.expect

This file was deleted.

35 changes: 35 additions & 0 deletions test/expect/TestJit.test_export.expect
@@ -0,0 +1,35 @@
node {
  input: "1"
  input: "2"
  output: "4"
  op_type: "Add"
}
node {
  input: "1"
  input: "4"
  output: "6"
  op_type: "Mul"
}
node {
  input: "6"
  output: "8"
  op_type: "TanH"
}
node {
  input: "8"
  output: "10"
  op_type: "Sigmoid"
}
node {
  input: "10"
  output: "12"
  op_type: "Scale"
  attribute {
    name: "scale"
    f: -1
  }
}
name: "torch-jit-export"
input: "1"
input: "2"
output: "12"
53 changes: 10 additions & 43 deletions test/test_jit.py
@@ -24,6 +24,16 @@ def test_simple(self):

        self.assertExpected(str(trace))

    def test_export(self):
        x = Variable(torch.Tensor([0.4]), requires_grad=True)
        y = Variable(torch.Tensor([0.7]), requires_grad=True)

        trace, (x, y) = torch._C._tracer_enter((x, y))
        z = -torch.sigmoid(torch.tanh(x * (x + y)))
        torch._C._tracer_exit((z,))
        torch._C._jit_pass_lint(trace)
        self.assertExpected(torch._C._jit_pass_export(trace))

    def test_lstm(self):
        # Careful: don't use fused backend (enabled with CUDA)
        # Pasted from test_LSTM_cell
@@ -109,49 +119,6 @@ def test_traced_module(self):
        out2 = lstm(input, (hx, cx))
        self.assertEqual(out, out2)

    @unittest.skip("in-place is not supported")
    def test_alexnet(self):

        class AlexNet(nn.Module):

            def __init__(self, num_classes=1000):
                super(AlexNet, self).__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                    nn.Conv2d(64, 192, kernel_size=5, padding=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                    nn.Conv2d(192, 384, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(384, 256, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(256, 256, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                )
                self.classifier = nn.Sequential(
                    nn.Dropout(),
                    nn.Linear(256 * 6 * 6, 4096),
                    nn.ReLU(inplace=True),
                    nn.Dropout(),
                    nn.Linear(4096, 4096),
                    nn.ReLU(inplace=True),
                    nn.Linear(4096, num_classes),
                )

            def forward(self, x):
                x = self.features(x)
                x = x.view(x.size(0), 256 * 6 * 6)
                x = self.classifier(x)
                return x

        model = torch.jit.traced(AlexNet())
        x = Variable(torch.randn(10, 3, 224, 224), requires_grad=True)
        trace, _ = model(x)
        self.assertExpected(str(trace))

    def test_autograd_closure(self):
        a = x = Variable(torch.Tensor([0.4]), requires_grad=True)
        b = y = Variable(torch.Tensor([0.7]), requires_grad=True)
19 changes: 19 additions & 0 deletions torch/autograd/_functions/basic_ops.py
@@ -1,11 +1,18 @@
import torch
import torch.toffee
from ..function import Function, InplaceFunction
from .utils import maybe_unexpand, maybe_unexpand_or_view
import math


class Add(InplaceFunction):

    @staticmethod
    def primspec(a, b, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Add", a, b)

    @staticmethod
    def forward(ctx, a, b, inplace=False):
        ctx.a_size = a.size()
@@ -40,6 +47,12 @@ def backward(ctx, grad_output):

class Mul(Function):

    @staticmethod
    def primspec(a, b, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Mul", a, b)

    @staticmethod
    def forward(ctx, a, b):
        ctx.a_size = a.size()
@@ -215,6 +228,12 @@ def backward(ctx, grad_output):

class Negate(InplaceFunction):

    @staticmethod
    def primspec(i, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Scale", i, scale=float(-1))

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
14 changes: 14 additions & 0 deletions torch/autograd/_functions/pointwise.py
@@ -1,3 +1,5 @@
import torch

from itertools import repeat

from ..._thnn import type2backend
@@ -52,6 +54,12 @@ def backward(ctx, grad_output):

class Tanh(InplaceFunction):

    @staticmethod
    def primspec(i, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("TanH", i)

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
@@ -77,6 +85,12 @@ def backward(ctx, grad_output):

class Sigmoid(InplaceFunction):

    @staticmethod
    def primspec(i, inplace=False):
        if inplace:
            return None
        return torch.toffee.op("Sigmoid", i)

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace: