[C++ API] Better forward methods (pytorch#8739)
* Better forward methods in C++ API

capitalize error message in test_torch.test_flatten

Support for operator()

* Add operator() to Functional

* Get rid of SigmoidLinear

* Add BoundFunction to FunctionalImpl

* Remove macro from conv because it makes errors more nasty
goldsborough committed Jun 26, 2018
1 parent f607794 commit 5575735
Showing 28 changed files with 375 additions and 335 deletions.
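For context, the central change in this commit is that forward() on C++ API modules now takes and returns a single torch::Tensor instead of a std::vector<torch::Tensor>, and Functional modules gain an operator() that forwards to forward() (see the test changes further down). A minimal sketch of the new call-site style; the <torch/torch.h> include, the using-directive, and the main() wrapper are assumptions made for a self-contained example and are not part of the diff:

#include <torch/torch.h>

using namespace torch::nn;

int main() {
  Linear model(5, 2);
  auto x = torch::randn({10, 5}, at::requires_grad());

  // Old call-site style: auto y = model->forward({x})[0];
  // New style: forward() takes one Tensor and returns one Tensor.
  auto y = model->forward(x);
  y.sum().backward();
  return 0;
}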
67 changes: 35 additions & 32 deletions aten/src/ATen/ExpandUtils.cpp
@@ -14,26 +14,28 @@ std::vector<int64_t> infer_size(IntList a, IntList b) {
long dimB = dimsB - 1 - offset;
long sizeA = (dimA >= 0) ? a[dimA] : 1;
long sizeB = (dimB >= 0) ? b[dimB] : 1;
if (sizeA == sizeB || sizeA == 1 || sizeB == 1) {
expandedSizes[i] = std::max(sizeA, sizeB);
} else {
std::ostringstream oss;
oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b ("
<< sizeB << ") at non-singleton dimension " << i;
throw std::runtime_error(oss.str());
}

AT_CHECK(
sizeA == sizeB || sizeA == 1 || sizeB == 1,
"The size of tensor a (", sizeA,
") must match the size of tensor b (", sizeB,
") at non-singleton dimension ", i);

expandedSizes[i] = std::max(sizeA, sizeB);
}

return expandedSizes;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t> >
inferExpandGeometry(const Tensor &tensor, IntList sizes) {
std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
const Tensor& tensor,
IntList sizes) {
int64_t ndim = sizes.size();

if (tensor.dim() == 0) {
std::vector<int64_t> expandedStrides(ndim, 0);
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(sizes.vec(), expandedStrides);
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
sizes.vec(), expandedStrides);
}
std::vector<int64_t> expandedSizes(ndim);
std::vector<int64_t> expandedStrides(ndim);
@@ -43,34 +45,35 @@ inferExpandGeometry(const Tensor &tensor, IntList sizes) {
int64_t offset = ndim - 1 - i;
int64_t dim = tensor.dim() - 1 - offset;
int64_t size = (dim >= 0) ? tensor.sizes()[dim] : 1;
int64_t stride = (dim >= 0) ?
tensor.strides()[dim] : expandedSizes[i + 1] * expandedStrides[i + 1];
int64_t stride = (dim >= 0) ? tensor.strides()[dim]
: expandedSizes[i + 1] * expandedStrides[i + 1];
int64_t targetSize = sizes[i];
if (targetSize == -1) {
if (dim < 0) {
std::ostringstream oss;
oss << "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, "
<< "non-existing dimension " << i;
throw std::runtime_error(oss.str());
} else {
targetSize = size;
}
AT_CHECK(
dim >= 0,
"The expanded size of the tensor (",
targetSize,
") isn't allowed in a leading, non-existing dimension ",
i);
targetSize = size;
}
if (size != targetSize) {
if (size == 1) {
size = targetSize;
stride = 0;
} else {
std::ostringstream oss;
oss << "The expanded size of the tensor (" << targetSize << ") must match the existing size (" << size
<< ") at non-singleton dimension " << i;
throw std::runtime_error(oss.str());
}
AT_CHECK(
size == 1,
"The expanded size of the tensor (",
targetSize,
") must match the existing size (",
size,
") at non-singleton dimension ",
i);
size = targetSize;
stride = 0;
}
expandedSizes[i] = size;
expandedStrides[i] = stride;
}
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(expandedSizes, expandedStrides);
return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
expandedSizes, expandedStrides);
}

}
} // namespace at
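Both hunks in this file, like the WrapDimUtils.h change below, replace hand-rolled std::ostringstream plus throw std::runtime_error error paths with ATen's variadic AT_CHECK macro: the first argument is the condition, and the remaining arguments are concatenated into the exception message only when the check fails. A minimal sketch of the pattern, written as if it sat next to infer_size so the same includes apply; the helper name checked_size is hypothetical:

// Hypothetical helper illustrating the AT_CHECK pattern used above.
int64_t checked_size(int64_t sizeA, int64_t sizeB, int64_t dim) {
  AT_CHECK(
      sizeA == sizeB || sizeA == 1 || sizeB == 1,
      "The size of tensor a (", sizeA,
      ") must match the size of tensor b (", sizeB,
      ") at non-singleton dimension ", dim);
  return std::max(sizeA, sizeB);
}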
10 changes: 4 additions & 6 deletions aten/src/ATen/WrapDimUtils.h
@@ -17,12 +17,10 @@ static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wr

int64_t min = -dim_post_expr;
int64_t max = dim_post_expr - 1;
if (dim < min || dim > max) {
std::ostringstream oss;
oss << "dimension out of range (expected to be in range of [" << min
<< ", " << max << "], but got " << dim << ")",
throw std::runtime_error(oss.str());
}
AT_CHECK(
dim >= min && dim <= max,
"Dimension out of range (expected to be in range of [",
min, ", ", max, "], but got ", dim, ")");
if (dim < 0) dim += dim_post_expr;
return dim;
}
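The behaviour of maybe_wrap_dim is unchanged here; only the failure path now goes through AT_CHECK. A rough illustration of what callers rely on; the include path, the wrap_dim_example wrapper, and the two-argument call (relying on the default of the trailing bool parameter visible in the signature above) are assumptions, not code from the diff:

#include <ATen/WrapDimUtils.h>
#include <cassert>

void wrap_dim_example() {
  // Negative dims wrap around; out-of-range dims now fail via AT_CHECK.
  assert(at::maybe_wrap_dim(-1, /*dim_post_expr=*/4) == 3);  // wraps to the last dim
  assert(at::maybe_wrap_dim(2, /*dim_post_expr=*/4) == 2);   // already in range
  // at::maybe_wrap_dim(4, 4) would now throw:
  //   "Dimension out of range (expected to be in range of [-4, 3], but got 4)"
}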
42 changes: 20 additions & 22 deletions test/cpp/api/integration.cpp
@@ -230,7 +230,7 @@ bool test_mnist(
std::cout << "Num correct: " << correct.data().sum().toCFloat() << " out of "
<< telabel.size(0) << std::endl;
return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
};
}

TEST_CASE("integration/cartpole") {
std::cerr << "Training episodic policy gradient with a critic for up to 3000"
@@ -245,16 +245,16 @@ TEST_CASE("integration/cartpole") {
std::vector<torch::Tensor> saved_values;
std::vector<float> rewards;

auto forward = [&](std::vector<torch::Tensor> inp) {
auto x = linear->forward(inp)[0].clamp_min(0);
torch::Tensor actions = policyHead->forward({x})[0];
torch::Tensor value = valueHead->forward({x})[0];
auto forward = [&](torch::Tensor inp) {
auto x = linear->forward(inp).clamp_min(0);
torch::Tensor actions = policyHead->forward(x);
torch::Tensor value = valueHead->forward(x);
return std::make_tuple(at::softmax(actions, -1), value);
};

auto selectAction = [&](torch::Tensor state) {
// Only work on single state now, change index to gather for batch
auto out = forward({state});
auto selectAction = [&](at::Tensor state) {
// Only work on single state right now, change index to gather for batch
auto out = forward(state);
auto probs = torch::Tensor(std::get<0>(out));
auto value = torch::Tensor(std::get<1>(out));
auto action = probs.data().multinomial(1)[0].toCInt();
@@ -340,16 +340,15 @@ TEST_CASE("integration/mnist", "[cuda]") {
auto linear2 = model->add(Linear(50, 10), "linear2");

auto forward = [&](torch::Tensor x) {
x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
.clamp_min(0);
x = conv2->forward({x})[0];
x = drop2d->forward({x})[0];
x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
x = conv2->forward(x);
x = drop2d->forward(x);
x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);

x = x.view({-1, 320});
x = linear1->forward({x})[0].clamp_min(0);
x = drop->forward({x})[0];
x = linear2->forward({x})[0];
x = linear1->forward(x).clamp_min(0);
x = drop->forward(x);
x = linear2->forward(x);
x = at::log_softmax(x, 1);
return x;
};
@@ -378,16 +377,15 @@ TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
auto linear2 = model->add(Linear(50, 10), "linear2");

auto forward = [&](torch::Tensor x) {
x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
.clamp_min(0);
x = batchnorm2d->forward({x})[0];
x = conv2->forward({x})[0];
x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
x = batchnorm2d->forward(x);
x = conv2->forward(x);
x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);

x = x.view({-1, 320});
x = linear1->forward({x})[0].clamp_min(0);
x = batchnorm1->forward({x})[0];
x = linear2->forward({x})[0];
x = linear1->forward(x).clamp_min(0);
x = batchnorm1->forward(x);
x = linear2->forward(x);
x = at::log_softmax(x, 1);
return x;
};
2 changes: 1 addition & 1 deletion test/cpp/api/misc.cpp
@@ -22,7 +22,7 @@ TEST_CASE("misc") {
torch::NoGradGuard guard;
Linear model(5, 2);
auto x = torch::randn({10, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
2 changes: 1 addition & 1 deletion test/cpp/api/module.cpp
@@ -38,7 +38,7 @@ TEST_CASE("module/training-mode") {
TEST_CASE("module/zero-grad") {
Linear module(3, 4);
auto weight = torch::ones({8, 3}, at::requires_grad());
auto loss = module->forward({weight}).front().sum();
auto loss = module->forward(weight).sum();
loss.backward();
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
72 changes: 39 additions & 33 deletions test/cpp/api/modules.cpp
@@ -21,10 +21,6 @@ class TestModel : public torch::nn::Module {
l3 = register_module("l3", Linear(5, 100));
}

std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
return input;
}

Linear l1, l2, l3;
};

@@ -36,10 +32,6 @@ class NestedModel : public torch::nn::Module {
param_ = register_parameter("param", torch::empty({3, 2, 21}));
}

std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
return input;
};

torch::Tensor param_;
Linear l1;
std::shared_ptr<TestModel> t;
@@ -50,7 +42,7 @@ TEST_CASE("modules") {
SECTION("1d") {
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -66,7 +58,7 @@ SECTION("even") {
SECTION("even") {
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -82,7 +74,7 @@ SECTION("uneven") {
SECTION("uneven") {
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
auto x = torch::randn({2, 3, 5, 4}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -98,7 +90,7 @@ SECTION("3d") {
SECTION("3d") {
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -116,7 +108,7 @@ SECTION("basic1") {
SECTION("basic1") {
Linear model(5, 2);
auto x = torch::randn({10, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -136,9 +128,9 @@ TEST_CASE("modules") {
auto l3 = model->add(Linear(5, 100), "l3");

auto x = torch::randn({1000, 10}, at::requires_grad());
x = l1->forward({x})[0].clamp_min(0);
x = l2->forward({x})[0].clamp_min(0);
x = l3->forward({x})[0].clamp_min(0);
x = l1->forward(x).clamp_min(0);
x = l2->forward(x).clamp_min(0);
x = l3->forward(x).clamp_min(0);

x.backward();
REQUIRE(x.ndimension() == 2);
@@ -154,7 +146,7 @@ TEST_CASE("modules") {
// Cannot get gradients to change indices (input) - only for embedding
// params
auto x = torch::full({10}, dict_size - 1, torch::kInt64);
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -169,7 +161,7 @@ SECTION("list") {
SECTION("list") {
Embedding model(6, 4);
auto x = torch::full({2, 3}, 5, torch::kInt64);
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -183,7 +175,7 @@ SECTION("dropout") {
SECTION("dropout") {
Dropout dropout(0.5);
torch::Tensor x = torch::ones(100, at::requires_grad());
torch::Tensor y = dropout->forward({x})[0];
torch::Tensor y = dropout->forward(x);

y.backward();
REQUIRE(y.ndimension() == 1);
@@ -194,7 +186,7 @@ TEST_CASE("modules") {
// REQUIRE(y.sum().toCFloat() > 70); // Probably

dropout->eval();
y = dropout->forward({x})[0];
y = dropout->forward(x);
REQUIRE(y.data().sum().toCFloat() == 100);
}

@@ -219,17 +211,31 @@ TEST_CASE("modules") {
}

SECTION("functional") {
bool was_called = false;
// clang-format off
auto functional = Functional([&was_called](std::vector<torch::Tensor> input) {
was_called = true;
return input;
});
// clang-format on
auto output = functional->forward({torch::ones(5, at::requires_grad())});
REQUIRE(was_called);
REQUIRE(output.size() == 1);
REQUIRE(output.front().equal(torch::ones(5, at::requires_grad())));
{
bool was_called = false;
auto functional = Functional([&was_called](torch::Tensor input) {
was_called = true;
return input;
});
auto output = functional->forward(torch::ones(5, at::requires_grad()));
REQUIRE(was_called);
REQUIRE(output.equal(torch::ones(5, at::requires_grad())));

was_called = false;
output = functional(torch::ones(5, at::requires_grad()));
REQUIRE(was_called);
REQUIRE(output.equal(torch::ones(5, at::requires_grad())));
}
{
auto functional = Functional(at::relu);
REQUIRE(functional(torch::ones({})).data().toCFloat() == 1);
REQUIRE(functional(torch::ones({})).toCFloat() == 1);
REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
}
{
auto functional = Functional(at::elu, /*alpha=*/1, /*scale=*/0);
REQUIRE(functional(torch::ones({})).toCFloat() == 0);
}
}
}
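The functional section above shows the heart of the new Functional behaviour: operator() forwards to forward(), and extra constructor arguments are bound up front (the BoundFunction mentioned in the commit message), so the wrapped function is still invoked with a single Tensor. A small stand-alone sketch of the same usage; the include, the using-directive, and the main() wrapper are assumptions, while the calls mirror the test above:

#include <torch/torch.h>

using namespace torch::nn;

int main() {
  // Functional wraps a plain function; the module can be applied directly.
  auto relu = Functional(at::relu);
  auto a = relu(torch::ones({}));    // a.toCFloat() == 1

  // Extra constructor arguments are bound, so the call site still passes
  // only a Tensor; with scale set to 0 the wrapped at::elu yields 0 here.
  auto elu = Functional(at::elu, /*alpha=*/1, /*scale=*/0);
  auto b = elu(torch::ones({}));     // b.toCFloat() == 0
  return 0;
}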

@@ -238,7 +244,7 @@ TEST_CASE("modules_cuda", "[cuda]") {
Linear model(5, 2);
model->cuda();
auto x = torch::randn({10, 5}, at::device(at::kCUDA).requires_grad(true));
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
@@ -255,7 +261,7 @@ TEST_CASE("modules_cuda", "[cuda]") {
model->cuda();
model->cpu();
auto x = torch::randn({10, 5}, at::requires_grad());
auto y = model->forward({x})[0];
auto y = model->forward(x);
torch::Tensor s = y.sum();

s.backward();
Diffs for the remaining 22 changed files are not shown.
