Skip to content

Commit

Permalink
[C++ API] Cursors (pytorch#8190)
Browse files Browse the repository at this point in the history
* Add cursors to C++ API

* Small self nits

* s/struct/class

* Use more STL like names for cursors
  • Loading branch information
goldsborough committed Jun 11, 2018
1 parent 77660a9 commit de4e97e
Show file tree
Hide file tree
Showing 17 changed files with 877 additions and 157 deletions.
49 changes: 26 additions & 23 deletions test/cpp/api/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ TEST_CASE("containers") {
REQUIRE(y.size(i) == 2);
}

REQUIRE(model->parameters().at("weight").grad().numel() == 3 * 2 * 3);
REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
}
SECTION("2d") {
SECTION("even") {
Expand All @@ -79,7 +79,7 @@ TEST_CASE("containers") {
}

REQUIRE(
model->parameters().at("weight").grad().numel() == 3 * 2 * 3 * 3);
model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
}

SECTION("uneven") {
Expand All @@ -96,7 +96,7 @@ TEST_CASE("containers") {
}

REQUIRE(
model->parameters().at("weight").grad().numel() == 3 * 2 * 3 * 2);
model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
}
}
SECTION("3d") {
Expand All @@ -113,7 +113,8 @@ TEST_CASE("containers") {
}

REQUIRE(
model->parameters().at("weight").grad().numel() == 3 * 2 * 3 * 3 * 3);
model->parameters()["weight"].grad().numel() ==
3 * 2 * 3 * 3 * 3);
}
}
SECTION("linear") {
Expand All @@ -129,7 +130,7 @@ TEST_CASE("containers") {
REQUIRE(y.size(0) == 10);
REQUIRE(y.size(1) == 2);

REQUIRE(model->parameters().at("weight").grad().numel() == 2 * 5);
REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}
}

Expand Down Expand Up @@ -167,7 +168,8 @@ TEST_CASE("containers") {
REQUIRE(y.size(0) == 10);
REQUIRE(y.size(1) == 2);

REQUIRE(model->parameters().at("table").grad().numel() == 2 * dict_size);
REQUIRE(
model->parameters()["table"].grad().numel() == 2 * dict_size);
}

SECTION("list") {
Expand Down Expand Up @@ -204,21 +206,22 @@ TEST_CASE("containers") {

SECTION("param") {
auto model = std::make_shared<NestedModel>();
REQUIRE(model->param("param").size(0) == 3);
REQUIRE(model->param("param").size(1) == 2);
REQUIRE(model->param("param").size(2) == 21);
REQUIRE(model->param("l1.bias").size(0) == 20);
REQUIRE(model->param("l1.weight").size(0) == 20);
REQUIRE(model->param("l1.weight").size(1) == 5);
REQUIRE(model->param("test.l1.bias").size(0) == 3);
REQUIRE(model->param("test.l1.weight").size(0) == 3);
REQUIRE(model->param("test.l1.weight").size(1) == 10);
REQUIRE(model->param("test.l2.bias").size(0) == 5);
REQUIRE(model->param("test.l2.weight").size(0) == 5);
REQUIRE(model->param("test.l2.weight").size(1) == 3);
REQUIRE(model->param("test.l3.bias").size(0) == 100);
REQUIRE(model->param("test.l3.weight").size(0) == 100);
REQUIRE(model->param("test.l3.weight").size(1) == 5);
auto parameters = model->parameters();
REQUIRE(parameters["param"].size(0) == 3);
REQUIRE(parameters["param"].size(1) == 2);
REQUIRE(parameters["param"].size(2) == 21);
REQUIRE(parameters["l1.bias"].size(0) == 20);
REQUIRE(parameters["l1.weight"].size(0) == 20);
REQUIRE(parameters["l1.weight"].size(1) == 5);
REQUIRE(parameters["test.l1.bias"].size(0) == 3);
REQUIRE(parameters["test.l1.weight"].size(0) == 3);
REQUIRE(parameters["test.l1.weight"].size(1) == 10);
REQUIRE(parameters["test.l2.bias"].size(0) == 5);
REQUIRE(parameters["test.l2.weight"].size(0) == 5);
REQUIRE(parameters["test.l2.weight"].size(1) == 3);
REQUIRE(parameters["test.l3.bias"].size(0) == 100);
REQUIRE(parameters["test.l3.weight"].size(0) == 100);
REQUIRE(parameters["test.l3.weight"].size(1) == 5);
}

SECTION("functional") {
Expand Down Expand Up @@ -250,7 +253,7 @@ TEST_CASE("containers_cuda", "[cuda]") {
REQUIRE(y.size(0) == 10);
REQUIRE(y.size(1) == 2);

REQUIRE(model->parameters().at("weight").grad().numel() == 2 * 5);
REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}

SECTION("2") {
Expand All @@ -267,6 +270,6 @@ TEST_CASE("containers_cuda", "[cuda]") {
REQUIRE(y.size(0) == 10);
REQUIRE(y.size(1) == 2);

REQUIRE(model->parameters().at("weight").grad().numel() == 2 * 5);
REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}
}
Loading

0 comments on commit de4e97e

Please sign in to comment.