
Commit c2558b4
[vulkan] Add nonVarTypeModeGuard to vulkan tests and speed_benchmark_torch (pytorch#52535)

Summary: Pull Request resolved: pytorch#52535

Test Plan: Imported from OSS

Reviewed By: ailzhang

Differential Revision: D26580994

Pulled By: SS-JIA

fbshipit-source-id: 94f091432265cf6607b73c34846c07273d47c70b
SS-JIA authored and facebook-github-bot committed Feb 25, 2021
1 parent e94940b commit c2558b4
Showing 3 changed files with 5 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/test/vulkan_api_test.cpp
@@ -70,6 +70,7 @@ TEST(VulkanAPITest, adaptive_avg_pool2d) {
if (!at::is_vulkan_available()) {
return;
}
at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);

const auto in_cpu = at::rand({5, 7, 47, 31}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
const auto out_cpu = at::adaptive_avg_pool2d(in_cpu, {3, 3});
@@ -617,6 +618,7 @@ TEST(VulkanAPITest, reshape) {
if (!at::is_vulkan_available()) {
return;
}
at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);

const auto in_cpu = at::rand({47, 11, 83, 97}, at::device(at::kCPU).dtype(at::kFloat));
const auto in_vulkan = in_cpu.vulkan();
@@ -638,6 +640,7 @@ TEST(VulkanAPITest, reshape_) {
if (!at::is_vulkan_available()) {
return;
}
at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);

const auto cpu = at::rand({59, 41, 19, 67}, at::device(at::kCPU).dtype(at::kFloat));
const auto vulkan = cpu.vulkan();
1 change: 1 addition & 0 deletions binaries/compare_models_torch.cc
@@ -224,6 +224,7 @@ int main(int argc, char** argv) {
float tolerance = 0;
ss >> tolerance;

at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);
torch::autograd::AutoGradMode guard(false);
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
auto module = torch::jit::load(FLAGS_model);
1 change: 1 addition & 0 deletions binaries/speed_benchmark_torch.cc
@@ -209,6 +209,7 @@ int main(int argc, char** argv) {

std::vector<c10::IValue> inputs = create_inputs();

at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);
torch::autograd::AutoGradMode guard(false);
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
auto module = torch::jit::load(FLAGS_model);
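For reference, all three files apply the same pattern: an at::AutoNonVariableTypeMode guard is constructed before any tensors are created or the module is run, alongside AutoGradMode(false), so that ops dispatch straight to the backend (Vulkan) kernels rather than through the VariableType autograd layer. Below is a minimal, self-contained sketch of that pattern; the includes, model path, and input shape are illustrative assumptions, not values taken from this commit.

// Minimal sketch of the guard pattern this commit applies.
// "model_for_vulkan.pt" and the input shape are placeholders.
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/script.h>

int main() {
  // Skip the VariableType (autograd) dispatch layer and disable gradient
  // recording, mirroring the guards added above.
  at::AutoNonVariableTypeMode nonVarTypeModeGuard(true);
  torch::autograd::AutoGradMode guard(false);

  auto module = torch::jit::load("model_for_vulkan.pt");
  const auto input = at::rand({1, 3, 224, 224}).vulkan();        // move the input to the Vulkan backend
  const auto output = module.forward({input}).toTensor().cpu();  // copy the result back to CPU
  return 0;
}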
