forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Created DefaultTensorOptions in ATen (pytorch#8647)
* Created DefaultTensorOptions * Fix TensorOptions() call which was interpreted as function decl * Fix empty OptionsGuard * Make options_ and mutex_ in DefaultTensorOptions class static because of dynamic linker issues * Make DefaultOptions thread local
- Loading branch information
1 parent
521f511
commit a5df8ec
Showing
7 changed files
with
190 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#include <ATen/OptionsGuard.h> | ||
#include <ATen/optional.h> | ||
|
||
namespace at { | ||
|
||
thread_local at::optional<TensorOptions> DefaultTensorOptions::options_; | ||
|
||
TensorOptions& DefaultTensorOptions::get() { | ||
if (!options_) { | ||
options_.emplace( | ||
/*use_thread_local_default_options=*/false); | ||
} | ||
return *options_; | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#pragma once | ||
|
||
#include <ATen/Device.h> | ||
#include <ATen/Layout.h> | ||
#include <ATen/ScalarType.h> | ||
#include <ATen/TensorOptions.h> | ||
#include <ATen/optional.h> | ||
|
||
namespace at { | ||
|
||
/// A wrapper over a thread local TensorOptions instance. | ||
struct DefaultTensorOptions { | ||
/// Returns the current thread local default options. | ||
/// Defined in OptionsGuard.cpp because we can't use optional in headers, due | ||
/// to Windows and other compilers. | ||
static TensorOptions& get(); | ||
|
||
private: | ||
/// This is an optional because of compiler bugs that mis-initialize static | ||
/// thread local variables. The workaround is lazy initialization, i.e. | ||
/// `DefaultTensorOptions::get()` will initialize the `options_` to a proper | ||
/// value upon first invocation. | ||
/// https://gcc.gnu.org/ml/gcc-bugs/2013-12/msg00026.html | ||
static thread_local at::optional<TensorOptions> options_; | ||
}; | ||
|
||
/// RAII guard that stores the current default options upon construction, sets | ||
/// the current default options to the ones given to its constructor, and | ||
/// finally resets the options back to the original ones in the destructor. | ||
struct OptionsGuard { | ||
/// Stores the current default options and sets them to the given ones. | ||
explicit OptionsGuard(const TensorOptions& options) | ||
: original_(DefaultTensorOptions::get()) { | ||
DefaultTensorOptions::get() = options; | ||
} | ||
|
||
/// Restores the original default options. | ||
~OptionsGuard() { | ||
DefaultTensorOptions::get() = original_; | ||
} | ||
|
||
/// Returns the original options that were in place at the time of | ||
/// construction of this object. | ||
const TensorOptions& original() { | ||
return original_; | ||
} | ||
|
||
private: | ||
/// The original options that were in place at the time of construction of | ||
/// this object. | ||
TensorOptions original_; | ||
}; | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include <ATen/TensorOptions.h> | ||
|
||
#include <ATen/Device.h> | ||
#include <ATen/Layout.h> | ||
#include <ATen/OptionsGuard.h> | ||
#include <ATen/ScalarType.h> | ||
#include <ATen/optional.h> | ||
|
||
namespace at { | ||
|
||
TensorOptions::TensorOptions(bool use_thread_local_default_options) { | ||
if (use_thread_local_default_options) { | ||
this->dtype(DefaultTensorOptions::get().dtype()); | ||
this->device(DefaultTensorOptions::get().device()); | ||
this->layout(DefaultTensorOptions::get().layout()); | ||
this->requires_grad(DefaultTensorOptions::get().requires_grad()); | ||
} | ||
} | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters