From c2558b4b616dd7caec5e8cc83b57bb14995f118b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 25 Feb 2021 14:21:26 -0800 Subject: [PATCH] [vulkan] Add nonVarTypeModeGuard to vulkan tests and speed_benchmark_torch (#52535) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52535 Test Plan: Imported from OSS Reviewed By: ailzhang Differential Revision: D26580994 Pulled By: SS-JIA fbshipit-source-id: 94f091432265cf6607b73c34846c07273d47c70b --- aten/src/ATen/test/vulkan_api_test.cpp | 3 +++ binaries/compare_models_torch.cc | 1 + binaries/speed_benchmark_torch.cc | 1 + 3 files changed, 5 insertions(+) diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index a5cb1d526d623..241096bb0edd0 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -70,6 +70,7 @@ TEST(VulkanAPITest, adaptive_avg_pool2d) { if (!at::is_vulkan_available()) { return; } + at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); const auto in_cpu = at::rand({5, 7, 47, 31}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::adaptive_avg_pool2d(in_cpu, {3, 3}); @@ -617,6 +618,7 @@ TEST(VulkanAPITest, reshape) { if (!at::is_vulkan_available()) { return; } + at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); const auto in_cpu = at::rand({47, 11, 83, 97}, at::device(at::kCPU).dtype(at::kFloat)); const auto in_vulkan = in_cpu.vulkan(); @@ -638,6 +640,7 @@ TEST(VulkanAPITest, reshape_) { if (!at::is_vulkan_available()) { return; } + at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); const auto cpu = at::rand({59, 41, 19, 67}, at::device(at::kCPU).dtype(at::kFloat)); const auto vulkan = cpu.vulkan(); diff --git a/binaries/compare_models_torch.cc b/binaries/compare_models_torch.cc index 6275087fd4fab..9dbce72d0e838 100644 --- a/binaries/compare_models_torch.cc +++ b/binaries/compare_models_torch.cc @@ -224,6 +224,7 @@ int main(int argc, char** argv) { float tolerance = 0; ss >> tolerance; + at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); torch::autograd::AutoGradMode guard(false); torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false); auto module = torch::jit::load(FLAGS_model); diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index 88cc0b5dd9562..f8db31436801e 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -209,6 +209,7 @@ int main(int argc, char** argv) { std::vector inputs = create_inputs(); + at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); torch::autograd::AutoGradMode guard(false); torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false); auto module = torch::jit::load(FLAGS_model);