Skip to content

Commit

Permalink
Created DefaultTensorOptions in ATen (pytorch#8647)
Browse files Browse the repository at this point in the history
* Created DefaultTensorOptions

* Fix TensorOptions() call which was interpreted as function decl

* Fix empty OptionsGuard

* Make options_ and mutex_ in DefaultTensorOptions class static because of dynamic linker issues

* Make DefaultOptions thread local
  • Loading branch information
goldsborough committed Jun 25, 2018
1 parent 521f511 commit a5df8ec
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 8 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/ATen.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
#include "ATen/DeviceGuard.h"
#include "ATen/TensorOptions.h"
#include "ATen/Layout.h"
#include "ATen/OptionsGuard.h"
16 changes: 16 additions & 0 deletions aten/src/ATen/OptionsGuard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <ATen/OptionsGuard.h>
#include <ATen/optional.h>

namespace at {

// Backing storage for the per-thread default options; lazily initialized
// inside get() rather than at thread start (see rationale in OptionsGuard.h).
thread_local at::optional<TensorOptions> DefaultTensorOptions::options_;

TensorOptions& DefaultTensorOptions::get() {
  if (!options_.has_value()) {
    // Passing `false` is essential: constructing with thread-local defaults
    // would call back into get() and recurse before options_ is set.
    options_.emplace(/*use_thread_local_default_options=*/false);
  }
  return *options_;
}

} // namespace at
54 changes: 54 additions & 0 deletions aten/src/ATen/OptionsGuard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#pragma once

#include <ATen/Device.h>
#include <ATen/Layout.h>
#include <ATen/ScalarType.h>
#include <ATen/TensorOptions.h>
#include <ATen/optional.h>

namespace at {

/// A wrapper over a thread local TensorOptions instance.
/// `OptionsGuard` (below) temporarily replaces the value returned by `get()`
/// for the current thread and restores it on scope exit.
struct DefaultTensorOptions {
  /// Returns the current thread local default options.
  /// Defined in OptionsGuard.cpp because we can't use optional in headers, due
  /// to Windows and other compilers.
  static TensorOptions& get();

 private:
  /// This is an optional because of compiler bugs that mis-initialize static
  /// thread local variables. The workaround is lazy initialization, i.e.
  /// `DefaultTensorOptions::get()` will initialize the `options_` to a proper
  /// value upon first invocation.
  /// https://gcc.gnu.org/ml/gcc-bugs/2013-12/msg00026.html
  static thread_local at::optional<TensorOptions> options_;
};

/// RAII guard that stores the current default options upon construction, sets
/// the current default options to the ones given to its constructor, and
/// finally resets the options back to the original ones in the destructor.
///
/// Non-copyable: a copy would restore `original_` a second time from its own
/// destructor, clobbering whatever defaults were active at that point.
struct OptionsGuard {
  /// Stores the current thread local default options and replaces them with
  /// `options` for the lifetime of this guard.
  explicit OptionsGuard(const TensorOptions& options)
      : original_(DefaultTensorOptions::get()) {
    DefaultTensorOptions::get() = options;
  }

  OptionsGuard(const OptionsGuard&) = delete;
  OptionsGuard& operator=(const OptionsGuard&) = delete;

  /// Restores the original default options.
  ~OptionsGuard() {
    DefaultTensorOptions::get() = original_;
  }

  /// Returns the original options that were in place at the time of
  /// construction of this object.
  const TensorOptions& original() const {
    return original_;
  }

 private:
  /// The original options that were in place at the time of construction of
  /// this object.
  TensorOptions original_;
};

} // namespace at
19 changes: 19 additions & 0 deletions aten/src/ATen/TensorOptions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <ATen/TensorOptions.h>

#include <ATen/Device.h>
#include <ATen/Layout.h>
#include <ATen/OptionsGuard.h>
#include <ATen/ScalarType.h>
#include <ATen/optional.h>

namespace at {

/// Constructs `TensorOptions`, optionally seeding every field (dtype, device,
/// layout, requires_grad) from the current thread local defaults. When
/// `use_thread_local_default_options` is false, the in-class member
/// initializers (kFloat, CPU, strided) are left untouched.
TensorOptions::TensorOptions(bool use_thread_local_default_options) {
  if (use_thread_local_default_options) {
    // Fetch the thread local defaults once instead of four times; each call
    // to get() repeats the lazy-initialization check.
    auto& defaults = DefaultTensorOptions::get();
    this->dtype(defaults.dtype());
    this->device(defaults.device());
    this->layout(defaults.layout());
    this->requires_grad(defaults.requires_grad());
  }
}
} // namespace at
15 changes: 10 additions & 5 deletions aten/src/ATen/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ namespace at {
/// `type()` to return a variable type instead of a tensor type, such that
/// variables are created inside factory methods, instead of tensors.
struct TensorOptions {
/// Constructs the `TensorOptions` with valid defaults, which are:
/// - dtype: float
/// - device: CPU
/// - layout: strided
TensorOptions() : TensorOptions(/*use_thread_local_default_options=*/true) {}

/// Constructs the `TensorOptions` with defaults taken from the thread local
/// `TensorOptions` object if `use_thread_local_default_options`, else
/// defaults to:
/// - dtype: kFloat,
/// - device: kCPU,
/// - layout: kStrided,
/// - requires_grad: false
TensorOptions() = default;
explicit TensorOptions(bool use_thread_local_default_options);

/// Constructs the `TensorOptions` from the type of the given `Tensor`.
/// If the `Tensor` has a CUDA type, the `device_index` will match that of the
Expand Down Expand Up @@ -192,6 +196,7 @@ struct TensorOptions {
return backend;
}

private:
ScalarType dtype_{kFloat};
Device device_{Device::Type::CPU};
Layout layout_{Layout::Strided};
Expand Down
37 changes: 36 additions & 1 deletion test/cpp/api/tensor_options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/OptionsGuard.h>
#include <ATen/TensorOptions.h>

#include <vector>
Expand All @@ -18,6 +19,12 @@ using namespace at;
REQUIRE(options.dtype() == (type_)); \
REQUIRE(options.layout() == (layout_))

// Asserts device/dtype/layout of a tensor. NOTE: relies on a local variable
// named `tensor` being in scope at the expansion site — it is not a parameter.
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
  REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
  REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
  REQUIRE(tensor.type().scalarType() == (type_)); \
  REQUIRE(tensor.type().layout() == (layout_))

TEST_CASE("TensorOptions/DefaultsToTheRightValues") {
TensorOptions options;
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
Expand Down Expand Up @@ -46,7 +53,7 @@ TEST_CASE("TensorOptions/UtilityFunctionsReturnTheRightTensorOptions") {
}

TEST_CASE("TensorOptions/ConstructsWellFromCPUTypes") {
auto options = TensorOptions();
TensorOptions options;
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);

options = TensorOptions({kCPU, 0});
Expand Down Expand Up @@ -99,3 +106,31 @@ TEST_CASE("Device/ParsesCorrectlyFromString") {
REQUIRE_THROWS(Device(badness));
}
}

// Verifies that OptionsGuard overrides the thread local defaults used by
// factory functions, and that tensors created under the guard keep the
// overridden options even after the guard is destroyed.
TEST_CASE("OptionsGuard") {
  Tensor tensor;
  {
    // Default-constructed options: the stock defaults should apply.
    OptionsGuard guard(TensorOptions{});
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCPU, -1, kFloat, kStrided);

  {
    // Overriding only the dtype.
    OptionsGuard guard(TensorOptions().dtype(kInt));
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCPU, -1, kInt, kStrided);

  {
    // Overriding dtype and layout together.
    OptionsGuard guard(TensorOptions().dtype(kInt).layout(kSparse));
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCPU, -1, kInt, kSparse);

  {
    OptionsGuard guard(requires_grad(true));
    // NOTE(review): uses torch::empty rather than at::empty — presumably so
    // the result is a variable that honors requires_grad; confirm.
    tensor = torch::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCPU, -1, kFloat, kStrided);
  REQUIRE(tensor.requires_grad());
}
56 changes: 54 additions & 2 deletions test/cpp/api/tensor_options_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "catch.hpp"

#include <ATen/Context.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Functions.h>
#include <ATen/OptionsGuard.h>
#include <ATen/TensorOptions.h>

#include <ATen/DeviceGuard.h>

using namespace at;

// A macro so we don't lose location information when an assertion fails.
Expand All @@ -15,6 +15,12 @@ using namespace at;
REQUIRE(options.dtype() == (type_)); \
REQUIRE(options.layout() == (layout_))

#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
REQUIRE(tensor.type().scalarType() == (type_)); \
REQUIRE(tensor.type().layout() == (layout_))

TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
auto options = TensorOptions(CUDA(kFloat));
REQUIRE_OPTIONS(kCUDA, -1, kFloat, kStrided);
Expand Down Expand Up @@ -59,3 +65,49 @@ TEST_CASE("TensorOptions/ConstructsWellFromCUDATensors", "[cuda]") {
REQUIRE_OPTIONS(kCUDA, 1, kFloat, kSparse);
}
}

// Verifies that OptionsGuard-controlled defaults route factory calls to CUDA.
TEST_CASE("OptionsGuardCUDA", "[cuda]") {
  Tensor tensor;

  // Device-only override: tensor lands on the default CUDA device (index 0).
  {
    OptionsGuard cuda_guard(device(kCUDA));
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCUDA, 0, kFloat, kStrided);

  // Device override with an explicit index.
  {
    OptionsGuard indexed_guard(device({kCUDA, 1}));
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCUDA, 1, kFloat, kStrided);

  // Device override combined with a dtype override.
  {
    OptionsGuard typed_guard(device(kCUDA).dtype(kInt));
    tensor = at::empty({10});
  }
  REQUIRE_TENSOR_OPTIONS(kCUDA, 0, kInt, kStrided);
}

// Verifies the interplay of DeviceGuard and OptionsGuard: an index-less
// device(kCUDA) default defers to whichever DeviceGuard is active at tensor
// creation time, while an explicit index wins over the active device.
TEST_CASE("DeviceGuardOptionsGuardInteraction", "[cuda]") {
  Tensor tensor;
  {
    // Check that OptionsGuard respects any active device before construction.
    DeviceGuard outer_device_guard(1);
    {
      OptionsGuard options_guard(device(kCUDA));
      tensor = at::empty({10});
      REQUIRE_TENSOR_OPTIONS(kCUDA, 1, kFloat, kStrided);
      {
        // Check that OptionsGuard respects any active device after
        // construction.
        DeviceGuard inner_device_guard(0);
        tensor = at::empty({10});
        REQUIRE_TENSOR_OPTIONS(kCUDA, 0, kFloat, kStrided);
        {
          // An explicit device index overrides the active DeviceGuard.
          OptionsGuard indexed_options_guard(device({kCUDA, 1}));
          tensor = at::empty({10});
          REQUIRE_TENSOR_OPTIONS(kCUDA, 1, kFloat, kStrided);
        }
      }
    }
  }
}

0 comments on commit a5df8ec

Please sign in to comment.