Skip to content

Commit

Permalink
Avoid applying the functor twice.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Dec 12, 2020
1 parent 4828584 commit 73644c7
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 11 deletions.
5 changes: 2 additions & 3 deletions src/torch/device.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ let cuda_if_available () = if Cuda.is_available () then Cuda 0 else Cpu
let is_cuda = function
| Cpu -> false
| Cuda _ -> true
;;

let get_num_threads = Torch_core.Device.get_num_threads
let set_num_threads = Torch_core.Device.set_num_threads
let get_num_threads = Torch_core.Wrapper.get_num_threads
let set_num_threads = Torch_core.Wrapper.set_num_threads
6 changes: 0 additions & 6 deletions src/wrapper/device.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,5 @@ let to_int = function
| Cuda i ->
if i < 0 then Printf.sprintf "negative index for cuda device" |> failwith;
i
;;

let of_int i = if i < 0 then Cpu else Cuda i

module C = Torch_bindings.C (Torch_generated)

let get_num_threads = C.get_num_threads
let set_num_threads = C.set_num_threads
2 changes: 0 additions & 2 deletions src/wrapper/device.mli
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,3 @@ type t =

val to_int : t -> int
val of_int : int -> t
val get_num_threads : unit -> int
val set_num_threads : int -> unit
2 changes: 2 additions & 0 deletions src/wrapper/wrapper.ml
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,5 @@ module Module = struct
end

let manual_seed seed = Wrapper_generated.C.manual_seed (Int64.of_int seed)
let set_num_threads = Wrapper_generated.C.set_num_threads
let get_num_threads = Wrapper_generated.C.get_num_threads
2 changes: 2 additions & 0 deletions src/wrapper/wrapper.mli
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
val manual_seed : int -> unit
val set_num_threads : int -> unit
val get_num_threads : unit -> int

module Scalar : sig
type _ t
Expand Down

0 comments on commit 73644c7

Please sign in to comment.