From 73644c7a08e65a62b9a7577fa1350c4f142fc9aa Mon Sep 17 00:00:00 2001 From: laurent Date: Sat, 12 Dec 2020 13:12:12 +0000 Subject: [PATCH] Avoid applying the functor twice. --- src/torch/device.ml | 5 ++--- src/wrapper/device.ml | 6 ------ src/wrapper/device.mli | 2 -- src/wrapper/wrapper.ml | 2 ++ src/wrapper/wrapper.mli | 2 ++ 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/torch/device.ml b/src/torch/device.ml index 58cde71..b0bb92a 100644 --- a/src/torch/device.ml +++ b/src/torch/device.ml @@ -7,7 +7,6 @@ let cuda_if_available () = if Cuda.is_available () then Cuda 0 else Cpu let is_cuda = function | Cpu -> false | Cuda _ -> true -;; -let get_num_threads = Torch_core.Device.get_num_threads -let set_num_threads = Torch_core.Device.set_num_threads +let get_num_threads = Torch_core.Wrapper.get_num_threads +let set_num_threads = Torch_core.Wrapper.set_num_threads diff --git a/src/wrapper/device.ml b/src/wrapper/device.ml index 02495f0..5d82b1c 100644 --- a/src/wrapper/device.ml +++ b/src/wrapper/device.ml @@ -8,11 +8,5 @@ let to_int = function | Cuda i -> if i < 0 then Printf.sprintf "negative index for cuda device" |> failwith; i -;; let of_int i = if i < 0 then Cpu else Cuda i - -module C = Torch_bindings.C (Torch_generated) - -let get_num_threads = C.get_num_threads -let set_num_threads = C.set_num_threads diff --git a/src/wrapper/device.mli b/src/wrapper/device.mli index 5a70afa..a1148ae 100644 --- a/src/wrapper/device.mli +++ b/src/wrapper/device.mli @@ -6,5 +6,3 @@ type t = val to_int : t -> int val of_int : int -> t -val get_num_threads : unit -> int -val set_num_threads : int -> unit diff --git a/src/wrapper/wrapper.ml b/src/wrapper/wrapper.ml index 723b472..55d26b9 100644 --- a/src/wrapper/wrapper.ml +++ b/src/wrapper/wrapper.ml @@ -433,3 +433,5 @@ module Module = struct end let manual_seed seed = Wrapper_generated.C.manual_seed (Int64.of_int seed) +let set_num_threads = Wrapper_generated.C.set_num_threads +let get_num_threads = Wrapper_generated.C.get_num_threads diff --git a/src/wrapper/wrapper.mli b/src/wrapper/wrapper.mli index 2e433c7..4ac0375 100644 --- a/src/wrapper/wrapper.mli +++ b/src/wrapper/wrapper.mli @@ -1,4 +1,6 @@ val manual_seed : int -> unit +val set_num_threads : int -> unit +val get_num_threads : unit -> int module Scalar : sig type _ t