A few additions (pytorch#9837)
Summary:
This PR provides 4 fixes / features:

1. torch::nn::Cloneable now inherits virtually from torch::nn::Module. We want to be able to pass around modules that add new virtual functions, and the cleanest way to do this is a diamond inheritance pattern, i.e.

```c++
struct MySuperModuleImpl : virtual public torch::nn::Module {
  virtual void myFunction() = 0;
};

template <typename Derived>
struct MySuperModule : public torch::nn::Cloneable<Derived>,
                       public MySuperModuleImpl {};

struct MyModule : public MySuperModule<MyModule> {
  void myFunction() override;
};
```

This way, we can simply pass around a MySuperModuleImpl instead of a torch::nn::Module.
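
A minimal usage sketch (omitting includes, like the example above, and assuming MyModule also overrides reset(), which torch::nn::Cloneable requires, and gives myFunction() a body). Because Module is now a virtual base, static_pointer_cast can no longer perform the downcast from the shared_ptr<Module> that clone() returns, which is why the casts throughout this diff switch to dynamic_pointer_cast:

```c++
// Callers depend only on the abstract interface, not on torch::nn::Module
// or the concrete module type.
void run(const std::shared_ptr<MySuperModuleImpl>& m) {
  m->myFunction();
}

void example() {
  auto module = std::make_shared<MyModule>();
  run(module);
  // clone() returns std::shared_ptr<torch::nn::Module>; with Module as a
  // virtual base, this downcast requires dynamic_pointer_cast.
  run(std::dynamic_pointer_cast<MyModule>(module->clone()));
}
```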

2. Optimizer options are now public, since there is otherwise no way to decay the learning rate or modify it during training (see the sketch after this list).
3. Serialization functions created autograd history and called copy_! Bad! They now run under a NoGradGuard and copy through tensor.data().
4. Optimizers did not create buffers for parameters added via add_parameters; buffers are now grown lazily the first time they are used.
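
As an example of point 2, learning-rate decay can now be written directly against the public options member. A sketch only: it assumes SGDOptions exposes a learning_rate_ field in the same style as the momentum_ and centered_ fields visible in the headers below, and model, kNumberOfEpochs, and train_one_epoch are placeholders.

```c++
auto optimizer = torch::optim::SGD(
    model->parameters(), torch::optim::SGDOptions(0.1).momentum(0.9));

for (int64_t epoch = 0; epoch < kNumberOfEpochs; ++epoch) {
  train_one_epoch(model, optimizer);  // placeholder training step
  // Exponential learning-rate decay; previously options() only returned a
  // const reference, so this mutation was impossible.
  optimizer.options.learning_rate_ *= 0.95;
}
```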
Pull Request resolved: pytorch#9837

Reviewed By: goldsborough

Differential Revision: D9199746

Pulled By: ebetica

fbshipit-source-id: 76d6b22e589a42637b7cc0b5bcd3c6b6662fb299
ebetica authored and facebook-github-bot committed Aug 13, 2018
1 parent 0a39a9c commit b8530dc
Showing 21 changed files with 141 additions and 120 deletions.
2 changes: 1 addition & 1 deletion test/cpp/api/module.cpp
@@ -289,7 +289,7 @@ TEST_CASE("module/clone") {
a->module->weight.data() += 1;
a->module->value = 123;

auto b = std::static_pointer_cast<NestedModule>(a->clone());
auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());

REQUIRE(!pointer_equal(b->module->weight, a->module->weight));
REQUIRE(
22 changes: 21 additions & 1 deletion test/cpp/api/optim.cpp
@@ -35,7 +35,8 @@ bool test_optimizer_xor(Options options) {
const int64_t kBatchSize = 4;
const int64_t kMaximumNumberOfEpochs = 3000;

auto optimizer = OptimizerClass(model->parameters(), options);
auto optimizer = OptimizerClass(std::vector<torch::Tensor>(), options);
optimizer.add_parameters(model->parameters());

float running_loss = 1;
int epoch = 0;
@@ -258,3 +259,22 @@ TEST_CASE("Optim/ExternalVectorOfParameters") {
REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
}

TEST_CASE("Optim/AddParameter/LBFGS") {
torch::manual_seed(0);

std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
std::vector<torch::Tensor> original_parameters = {parameters[0].clone()};

// Set all gradients to one
for (auto& parameter : parameters) {
parameter.grad() = torch::ones_like(parameter);
}

LBFGS optimizer(std::vector<torch::Tensor>(), 1.0);
optimizer.add_parameters(parameters);

optimizer.step([]() { return torch::tensor(1); });

// REQUIRE this doesn't throw
}
4 changes: 2 additions & 2 deletions test/cpp/api/sequential.cpp
@@ -278,7 +278,7 @@ TEST_CASE("sequential") {
SECTION("Is cloneable") {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
Sequential clone =
std::static_pointer_cast<SequentialImpl>(sequential->clone());
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
REQUIRE(sequential->size() == clone->size());

for (size_t i = 0; i < sequential->size(); ++i) {
@@ -309,7 +309,7 @@ TEST_CASE("sequential/clone-to-device", "[cuda]") {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
torch::Device device(torch::kCUDA, 0);
Sequential clone =
std::static_pointer_cast<SequentialImpl>(sequential->clone(device));
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
for (const auto& p : clone->parameters()) {
REQUIRE(p->device() == device);
}
22 changes: 15 additions & 7 deletions test/cpp/api/serialization.cpp
@@ -228,6 +228,14 @@ TEST_CASE("serialization") {
ss.seekg(0, std::ios::beg);
torch::load(ss, model3.get());

auto param1 = model1->parameters();
auto param2 = model2->parameters();
auto param3 = model3->parameters();
for (const auto& p : param1) {
REQUIRE(param1[p.key].allclose(param2[p.key]));
REQUIRE(param2[p.key].allclose(param3[p.key]));
}

// Make some optimizers with momentum (and thus state)
auto optim1 = torch::optim::SGD(
model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
@@ -240,9 +248,9 @@ TEST_CASE("serialization") {
auto optim3_2 = torch::optim::SGD(
model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

auto x = torch::ones({10, 5}, torch::requires_grad());
auto x = torch::ones({10, 5});

auto step = [&](torch::optim::Optimizer& optimizer, Linear model) {
auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
optimizer.zero_grad();
auto y = model->forward(x).sum();
y.backward();
@@ -264,11 +272,11 @@ TEST_CASE("serialization") {
torch::load(ss, &optim3_2);
step(optim3_2, model3);

auto param1 = model1->parameters();
auto param2 = model2->parameters();
auto param3 = model3->parameters();
for (auto& p : param1) {
auto& name = p.key;
param1 = model1->parameters();
param2 = model2->parameters();
param3 = model3->parameters();
for (const auto& p : param1) {
const auto& name = p.key;
// Model 1 and 3 should be the same
REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
2 changes: 1 addition & 1 deletion torch/csrc/api/include/torch/nn/cloneable.h
@@ -21,7 +21,7 @@ namespace nn {
/// `clone()` method. We do not want to use this pattern in the base class,
/// because then storing a module would always require templatizing it.
template <typename Derived>
class Cloneable : public Module {
class Cloneable : public virtual Module {
public:
using Module::Module;

2 changes: 1 addition & 1 deletion torch/csrc/api/include/torch/nn/module.h
@@ -205,7 +205,7 @@ std::shared_ptr<ModuleType> Module::register_module(
std::string name,
std::shared_ptr<ModuleType> module) {
auto& base_module = children_.insert(std::move(name), std::move(module));
return std::static_pointer_cast<ModuleType>(base_module);
return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
2 changes: 1 addition & 1 deletion torch/csrc/api/include/torch/nn/modules/any.h
@@ -315,7 +315,7 @@ struct AnyModule::Holder : public AnyModule::Placeholder {
std::unique_ptr<Placeholder> clone(
at::optional<Device> device) const override {
return torch::make_unique<Holder>(
std::static_pointer_cast<ModuleType>(module->clone(device)));
std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
}

/// The actual concrete module instance.
10 changes: 4 additions & 6 deletions torch/csrc/api/include/torch/optim/adagrad.h
@@ -29,13 +29,13 @@ class Adagrad : public Optimizer {
ParameterContainer&& parameters,
const AdagradOptions& options)
: Optimizer(std::forward<ParameterContainer>(parameters)),
options_(options),
options(options),
sum_(zero_buffers_like(parameters_)),
step_(parameters_.size(), 0) {}

void step() override;

const AdagradOptions& options() const noexcept;
AdagradOptions options;

template <class Archive>
void serialize(Archive& ar) {
@@ -45,12 +45,10 @@ class Adagrad : public Optimizer {

private:
friend class cereal::access;
Adagrad() : options_(0) {}

AdagradOptions options_;
Adagrad() : options(0) {}

std::vector<Tensor> sum_;
std::vector<double> step_;
std::vector<int64_t> step_;
};
} // namespace optim
} // namespace torch
10 changes: 4 additions & 6 deletions torch/csrc/api/include/torch/optim/adam.h
@@ -30,11 +30,11 @@ class Adam : public Optimizer {
template <typename ParameterContainer>
explicit Adam(ParameterContainer&& parameters, const AdamOptions& options)
: Optimizer(std::forward<ParameterContainer>(parameters)),
options_(options),
options(options),
step_buffers_(parameters_.size(), 0),
exp_average_buffers_(zero_buffers_like(parameters_)),
exp_average_sq_buffers_(zero_buffers_like(parameters_)) {
if (options_.amsgrad_) {
if (options.amsgrad_) {
max_exp_average_sq_buffers_ = zero_buffers_like(parameters_);
}
}
@@ -49,13 +49,11 @@ class Adam : public Optimizer {
CEREAL_NVP(max_exp_average_sq_buffers_));
}

const AdamOptions& options() const noexcept;
AdamOptions options;

private:
friend class cereal::access;
Adam() : options_(0) {}

AdamOptions options_;
Adam() : options(0) {}

std::vector<int64_t> step_buffers_;
std::vector<Tensor> exp_average_buffers_;
12 changes: 5 additions & 7 deletions torch/csrc/api/include/torch/optim/lbfgs.h
@@ -31,13 +31,13 @@ class LBFGS : public LossClosureOptimizer {
template <typename ParameterContainer>
explicit LBFGS(ParameterContainer&& parameters, const LBFGSOptions& options)
: LossClosureOptimizer(std::forward<ParameterContainer>(parameters)),
options_(options),
ro(options_.history_size_),
al(options_.history_size_) {}
options(options),
ro(options.history_size_),
al(options.history_size_) {}

torch::Tensor step(LossClosure closure) override;

const LBFGSOptions& options() const noexcept;
LBFGSOptions options;

template <class Archive>
void serialize(Archive& ar) {
@@ -52,13 +52,11 @@ class LBFGS : public LossClosureOptimizer {

private:
friend class cereal::access;
LBFGS() : options_(0) {}
LBFGS() : options(0) {}

at::Tensor gather_flat_grad();
void add_grad(const torch::Scalar& step_size, const at::Tensor& update);

LBFGSOptions options_;

at::Tensor d{torch::empty({0})};
at::Tensor H_diag{torch::empty({0})};
at::Tensor prev_flat_grad{torch::empty({0})};
20 changes: 20 additions & 0 deletions torch/csrc/api/include/torch/optim/optimizer.h
@@ -4,6 +4,7 @@
#include <torch/nn/cursor.h>
#include <torch/tensor.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <vector>
@@ -64,9 +65,28 @@ class OptimizerBase {
return result;
}

/// Accesses the buffer at the given index.
/// If the buffer does not exist yet, the vector is grown and the new
/// entries are zero-initialized.
template<typename T>
T& buffer_at(std::vector<T>& buffers, size_t index) {
if (buffers.size() <= index) {
const auto old_size = buffers.size();
buffers.resize(index + 1);
std::fill(buffers.begin() + old_size, buffers.end(), T{0});
}
return buffers[index];
}

/// Accesses the buffer at the given index, converted to the device and
/// dtype of the parameter at the corresponding index (a no-op if they
/// already match). Missing buffers up to the index are created as
/// zeros_like the corresponding parameter.
Tensor& buffer_at(std::vector<Tensor>& buffers, size_t index) {
if (buffers.size() <= index) {
for (auto i = buffers.size(); i <= index; i++) {
buffers.push_back(torch::zeros_like(parameters_.at(i)));
}
}
// Copy the buffer to the device and dtype of the parameter.
const auto& parameter = parameters_.at(index);
const auto& buffer = buffers.at(index);
if (buffer.device() != parameter.device() ||
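
A sketch of how a step() implementation can lean on buffer_at (hypothetical optimizer; step_buffers_ and exp_average_buffers_ mirror the Adam members shown earlier in this diff). Both overloads grow their vector and zero-initialize the new entries on first access, so a parameter appended via add_parameters() after construction still gets a buffer the first time the optimizer touches it.

```c++
void MyOptimizer::step() {
  for (size_t i = 0; i < parameters_.size(); ++i) {
    auto& step_count = buffer_at(step_buffers_, i);          // int64_t, starts at 0
    auto& exp_average = buffer_at(exp_average_buffers_, i);  // zeros_like(parameter),
                                                             // right device/dtype
    step_count += 1;
    // ... use exp_average and parameters_[i].grad() in the parameter update ...
  }
}
```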
8 changes: 3 additions & 5 deletions torch/csrc/api/include/torch/optim/rmsprop.h
@@ -33,7 +33,7 @@ class RMSprop : public Optimizer {
ParameterContainer&& parameters,
const RMSpropOptions& options)
: Optimizer(std::forward<ParameterContainer>(parameters)),
options_(options),
options(options),
square_average_buffers_(zero_buffers_like(parameters_)) {
if (options.centered_ > 0) {
grad_average_buffers_ = zero_buffers_like(parameters_);
@@ -45,7 +45,7 @@ class RMSprop : public Optimizer {

void step() override;

const RMSpropOptions& options() const noexcept;
RMSpropOptions options;

template <class Archive>
void serialize(Archive& ar) {
@@ -56,9 +56,7 @@ class RMSprop : public Optimizer {

private:
friend class cereal::access;
RMSprop() : options_(0) {}

RMSpropOptions options_;
RMSprop() : options(0) {}

std::vector<Tensor> square_average_buffers_;
std::vector<Tensor> momentum_buffers_;
9 changes: 4 additions & 5 deletions torch/csrc/api/include/torch/optim/sgd.h
@@ -30,8 +30,8 @@ class SGD : public Optimizer {
template <typename ParameterContainer>
explicit SGD(ParameterContainer&& parameters, const SGDOptions& options)
: Optimizer(std::forward<ParameterContainer>(parameters)),
options_(options) {
if (options_.momentum_ > 0) {
options(options) {
if (options.momentum_ > 0) {
momentum_buffers_ = zero_buffers_like(parameters_);
}
}
@@ -43,13 +43,12 @@ class SGD : public Optimizer {
ar(CEREAL_NVP(momentum_buffers_));
}

const SGDOptions& options() const noexcept;
SGDOptions options;

private:
friend class cereal::access;
SGD() : options_(0) {}
SGD() : options(0) {}

SGDOptions options_;
std::vector<Tensor> momentum_buffers_;
/// Counts how often `step()` is called, for dampening.
size_t iteration_{0};
11 changes: 7 additions & 4 deletions torch/csrc/api/include/torch/serialization.h
@@ -4,6 +4,7 @@

#include <torch/tensor.h>
#include <torch/optim.h>
#include <torch/utils.h>

#include "cereal/archives/binary.hpp"
#include "cereal/types/polymorphic.hpp"
@@ -168,12 +169,13 @@ loadBinary(BinaryInputArchive& archive, void* data, size_t size) {
// Gradients will not be saved for variables
template <class Archive>
void save(Archive& archive, torch::Tensor const& tensor) {
torch::NoGradGuard guard;
if (!tensor.defined()) {
int32_t typeId = ::torch::detail::scalarTypeId(torch::Dtype::Undefined);
archive(CEREAL_NVP(typeId));
return;
} else {
int32_t typeId = ::torch::detail::scalarTypeId(tensor.data().type().scalarType());
int32_t typeId = ::torch::detail::scalarTypeId(tensor.dtype());
archive(CEREAL_NVP(typeId));
}
auto sizes = std::vector<int64_t>();
@@ -199,6 +201,7 @@ void save(Archive& archive, torch::Tensor const& tensor) {
**/
template <class Archive>
void load(Archive& archive, torch::Tensor& tensor) {
torch::NoGradGuard guard;
torch::Dtype type;
int32_t typeId;
archive(CEREAL_NVP(typeId));
@@ -214,19 +217,19 @@ void load(Archive& archive, torch::Tensor& tensor) {
archive(CEREAL_NVP(backendId), CEREAL_NVP(sizes));

at::Backend backend = ::torch::detail::backendFromId(backendId);
if (!tensor.defined() || tensor.data().type().scalarType() != type) {
if (!tensor.defined() || tensor.dtype() != type) {
tensor = torch::empty({}, torch::getType(backend, type));
}
tensor.data().resize_(sizes);

if (tensor.type().is_cuda()) {
// should actually use cudamemcpy probably
auto cputensor = torch::empty(sizes, tensor.data().type().scalarType());
auto cputensor = torch::empty(sizes, tensor.dtype());
agimpl::loadBinary(
archive,
cputensor.data_ptr(),
cputensor.numel() * cputensor.type().elementSizeInBytes());
tensor.copy_(cputensor);
tensor.data().copy_(cputensor.data());
} else {
agimpl::loadBinary(
archive,
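
As a round-trip sketch of the serialization fix (point 3 above): save/load now runs under NoGradGuard and copies through tensor.data(), so reloading neither records autograd history nor invalidates existing parameter references. This assumes the stream-based torch::save / torch::load overloads used by the updated serialization test above.

```c++
#include <torch/torch.h>

#include <sstream>

void roundtrip() {
  auto model = torch::nn::Linear(5, 2);
  std::stringstream stream;
  torch::save(stream, model.get());
  stream.seekg(0, std::ios::beg);
  // Loads in place; after this change no autograd graph nodes are created here.
  torch::load(stream, model.get());
}
```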