Skip to content

Commit

Permalink
Read safetensors files.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 25, 2023
1 parent 88dc470 commit d860cb9
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 44 deletions.
72 changes: 38 additions & 34 deletions bin/tensor_tools.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,44 @@ let npz_tensors ~filename ~f =

(* [ls files] prints, for every file name in [files], the file name followed by
   one line per tensor giving the tensor name and its comma-separated shape.
   The loader is picked from the file name suffix: .npz files go through
   [npz_tensors], .safetensors files through [Safetensors.read], and anything
   else through [Serialize.load_all].
   NOTE(review): this span shows the body twice, once without and once with
   the .safetensors branch; it looks like the removed and added sides of the
   diff - confirm before treating it as compilable code. *)
let ls files =
List.iter files ~f:(fun filename ->
Stdio.printf "%s:\n" filename;
let tensor_names_and_shapes =
if String.is_suffix filename ~suffix:".npz"
then
npz_tensors ~filename ~f:(fun tensor_name packed_tensor ->
match packed_tensor with
| Npy.P tensor ->
let tensor_shape = Bigarray.Genarray.dims tensor |> Array.to_list in
tensor_name, tensor_shape)
else
Serialize.load_all ~filename
|> List.map ~f:(fun (tensor_name, tensor) -> tensor_name, Tensor.shape tensor)
in
List.iter tensor_names_and_shapes ~f:(fun (tensor_name, shape) ->
let shape = List.map shape ~f:Int.to_string |> String.concat ~sep:", " in
Stdio.printf " %s (%s)\n" tensor_name shape))
Stdio.printf "%s:\n" filename;
let tensor_names_and_shapes =
if String.is_suffix filename ~suffix:".npz"
then
(* npz entries are bigarrays: the shape is read from the bigarray dims. *)
npz_tensors ~filename ~f:(fun tensor_name packed_tensor ->
match packed_tensor with
| Npy.P tensor ->
let tensor_shape = Bigarray.Genarray.dims tensor |> Array.to_list in
tensor_name, tensor_shape)
else if String.is_suffix filename ~suffix:".safetensors"
then
(* safetensors entries are loaded as tensors and only their shape is kept. *)
Safetensors.read filename
|> List.map ~f:(fun (tensor_name, tensor) -> tensor_name, Tensor.shape tensor)
else
(* Every other suffix is handed to Serialize.load_all. *)
Serialize.load_all ~filename
|> List.map ~f:(fun (tensor_name, tensor) -> tensor_name, Tensor.shape tensor)
in
List.iter tensor_names_and_shapes ~f:(fun (tensor_name, shape) ->
let shape = List.map shape ~f:Int.to_string |> String.concat ~sep:", " in
Stdio.printf " %s (%s)\n" tensor_name shape))

(* [npz_to_pytorch npz_src pytorch_dst] reads every entry of the npz file
   [npz_src] through [npz_tensors], converts each C-layout bigarray to a tensor
   with [Tensor.of_bigarray], and writes all named tensors to [pytorch_dst]
   with [Serialize.save_multi].
   Fails via [failwith] when an entry uses the Fortran layout.
   NOTE(review): the match arms below appear twice; this looks like the removed
   and added sides of a re-indentation diff - confirm. *)
let npz_to_pytorch npz_src pytorch_dst =
let named_tensors =
npz_tensors ~filename:npz_src ~f:(fun tensor_name packed_tensor ->
match packed_tensor with
| Npy.P tensor ->
(match Bigarray.Genarray.layout tensor with
| Bigarray.C_layout -> tensor_name, Tensor.of_bigarray tensor
| Bigarray.Fortran_layout -> failwith "fortran layout is not supported"))
match packed_tensor with
| Npy.P tensor ->
(match Bigarray.Genarray.layout tensor with
| Bigarray.C_layout -> tensor_name, Tensor.of_bigarray tensor
| Bigarray.Fortran_layout -> failwith "fortran layout is not supported"))
in
Serialize.save_multi ~named_tensors ~filename:pytorch_dst

let image_to_tensor image_src pytorch_dst resize =
let resize =
Option.map resize ~f:(fun resize ->
match String.split_on_chars resize ~on:[ 'x'; ',' ] with
| [ w; h ] -> Int.of_string w, Int.of_string h
| _ -> Printf.failwithf "resize should have format WxH, e.g. 64x64" ())
match String.split_on_chars resize ~on:[ 'x'; ',' ] with
| [ w; h ] -> Int.of_string w, Int.of_string h
| _ -> Printf.failwithf "resize should have format WxH, e.g. 64x64" ())
in
let tensor =
if Caml.Sys.is_directory image_src
Expand All @@ -59,16 +63,16 @@ let pytorch_to_npz pytorch_src npz_dst =
let named_tensors = Serialize.load_all ~filename:pytorch_src in
let npz_file = Npy.Npz.open_out npz_dst in
List.iter named_tensors ~f:(fun (tensor_name, tensor) ->
let write kind =
let tensor = Tensor.to_bigarray tensor ~kind in
Npy.Npz.write npz_file tensor_name tensor
in
match Tensor.kind tensor with
| T Float -> write Bigarray.float32
| T Double -> write Bigarray.float64
| T Int -> write Bigarray.int32
| T Int64 -> write Bigarray.int64
| _ -> Printf.failwithf "unsupported tensor kind for %s" tensor_name ());
let write kind =
let tensor = Tensor.to_bigarray tensor ~kind in
Npy.Npz.write npz_file tensor_name tensor
in
match Tensor.kind tensor with
| T Float -> write Bigarray.float32
| T Double -> write Bigarray.float64
| T Int -> write Bigarray.int32
| T Int64 -> write Bigarray.int64
| _ -> Printf.failwithf "unsupported tensor kind for %s" tensor_name ());
Npy.Npz.close_out npz_file

let () =
Expand Down
2 changes: 1 addition & 1 deletion src/torch/dune
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
(library
(name torch)
(public_name torch)
(libraries base stdio torch_core))
(libraries base int_repr stdio torch_core yojson))
100 changes: 100 additions & 0 deletions src/torch/safetensors.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
open! Base

(** Raised when a safetensors file cannot be decoded; the payload is a
    human-readable description of the problem. *)
exception Read_error of string

(** [read_error fmt arg1 ... argN] formats a message printf-style and raises
    {!Read_error} with it. It never returns, so it can be used in any branch
    regardless of the expected type. *)
let read_error fmt =
  let fail message = raise (Read_error message) in
  Printf.ksprintf fail fmt

(** [map_file file_descr ~pos ~len] memory-maps [len] bytes of [file_descr]
    starting at byte offset [pos] and returns them as a one-dimensional,
    C-layout bigarray of unsigned 8-bit integers. The mapping is requested
    with the [shared] flag set to [false]. *)
let map_file file_descr ~pos ~len =
  let offset = Int_conversions.int_to_int64 pos in
  let dims = [| len |] in
  let shared = false in
  Unix.map_file
    file_descr
    ~pos:offset
    Bigarray.int8_unsigned
    Bigarray.c_layout
    shared
    dims

(** [read ?only filename] loads the tensors stored in the safetensors file
    [filename] and returns them as [(name, tensor)] pairs, in the order in
    which they appear in the file header.

    The function reads an 8-byte little-endian unsigned header size, then a
    JSON object of that many bytes. Each entry other than [__metadata__]
    must be a JSON object with the fields [dtype], [shape] and
    [data_offsets]; the tensor bytes are memory-mapped from the position
    [8 + header_size + start_offset] and handed to
    [Tensor.of_bigarray_bytes].

    @param only when given, only the tensors whose names are in this list are
    loaded; the other entries are skipped without being validated.

    @raise Read_error when the file ends before the header is complete, the
    header is not a JSON object, an entry is not a JSON object, a field is
    missing, duplicated or has an unexpected value (unknown dtype, negative
    dimension, negative or reversed offsets), or the byte range given by
    [data_offsets] does not match the size implied by [dtype] and [shape].

    Exceptions raised by the underlying I/O, JSON parser and integer
    conversions are not caught here and propagate unchanged. *)
let read ?only filename =
  let only = Option.map only ~f:(Hash_set.of_list (module String)) in
  (* Without [?only] every tensor is wanted. *)
  let wanted tensor_name =
    Option.value_map only ~default:true ~f:(fun only -> Hash_set.mem only tensor_name)
  in
  Stdio.In_channel.with_file filename ~binary:true ~f:(fun in_c ->
    let header_size =
      match In_channel.really_input_string in_c 8 with
      | None -> read_error "unexpected eof while reading header size"
      | Some header_size -> header_size
    in
    let header_size =
      Int_repr.String.get_uint64_le header_size ~pos:0
      |> Int_repr.Uint64.to_base_int64_exn
      |> Base.Int_conversions.int64_to_int_exn
    in
    let header =
      match In_channel.really_input_string in_c header_size with
      | None -> read_error "unexpected eof while reading header len:%d" header_size
      | Some header -> header
    in
    let header =
      match Yojson.Safe.from_string header with
      | `Assoc assoc -> assoc
      | _ -> read_error "header is not a json object"
    in
    let fd = Unix.descr_of_in_channel in_c in
    List.filter_map header ~f:(function
      | "__metadata__", _ -> None
      | tensor_name, `Assoc details when wanted tensor_name ->
        let details =
          (* Report duplicated fields as a read error rather than letting a
             generic exception escape. *)
          match Hashtbl.of_alist (module String) details with
          | `Ok details -> details
          | `Duplicate_key key ->
            read_error "duplicate key %s in header details for %s" key tensor_name
        in
        let packed_ty =
          match Hashtbl.find details "dtype" with
          | None -> read_error "missing dtype for %s" tensor_name
          | Some (`String "F16") -> Torch_core.Kind.T Half
          | Some (`String "F32") -> T Float
          | Some (`String "F64") -> T Double
          | Some (`String "I64") -> T Int64
          | Some (`String "I32") -> T Int
          | Some (`String "I16") -> T Int16
          | Some (`String "I8") -> T Int8
          | Some (`String "U8") -> T Uint8
          | Some dtype ->
            read_error
              "unexpected dtype for %s: %s"
              tensor_name
              (Yojson.Safe.to_string dtype)
        in
        let start_offset, stop_offset =
          match Hashtbl.find details "data_offsets" with
          | None -> read_error "missing data_offsets for %s" tensor_name
          (* Offsets must describe a forward, non-negative byte range. *)
          | Some (`List [ `Int start; `Int stop ]) when 0 <= start && start <= stop ->
            start, stop
          | Some other ->
            read_error
              "unexpected data_offsets for %s: %s"
              tensor_name
              (Yojson.Safe.to_string other)
        in
        let shape =
          match Hashtbl.find details "shape" with
          | None -> read_error "missing shape for %s" tensor_name
          | Some (`List dims) ->
            List.map dims ~f:(function
              | `Int dim when dim >= 0 -> dim
              | other ->
                read_error
                  "unexpected shape for %s: %s"
                  tensor_name
                  (Yojson.Safe.to_string other))
          | Some other ->
            read_error
              "unexpected shape for %s: %s"
              tensor_name
              (Yojson.Safe.to_string other)
        in
        (* The mapped buffer is passed on as a raw byte array together with
           [shape], so make sure the header is self-consistent before the
           bytes are interpreted: the range must hold exactly
           element_size * product(shape) bytes. *)
        let data_len = stop_offset - start_offset in
        let expected_len =
          List.fold
            shape
            ~init:(Torch_core.Kind.packed_element_size_in_bytes packed_ty)
            ~f:( * )
        in
        if data_len <> expected_len
        then
          read_error
            "data_offsets for %s span %d bytes but dtype and shape require %d"
            tensor_name
            data_len
            expected_len;
        let src = map_file fd ~pos:(8 + header_size + start_offset) ~len:data_len in
        let tensor = Tensor.of_bigarray_bytes src packed_ty ~shape in
        Some (tensor_name, tensor)
      | _, `Assoc _ -> None
      | tensor_name, _ ->
        read_error "header details for %s is not a json object" tensor_name))
3 changes: 3 additions & 0 deletions src/torch/safetensors.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
open! Base

(** Raised by {!read} when the file content is not a well-formed safetensors
    file; the payload is a human-readable description of the problem. *)
exception Read_error of string

(** [read ?only filename] loads the tensors stored in the safetensors file
    [filename] and returns them as [(name, tensor)] pairs.

    @param only when given, only the tensors whose names are in this list are
    loaded.
    @raise Read_error when the header or one of its entries is malformed. *)
val read : ?only:string list -> string -> (string * Torch_core.Wrapper.Tensor.t) list
1 change: 1 addition & 0 deletions src/torch/torch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module Module = Module
(* [Nn] is an alias to [Layer] to keep coherence with the pytorch names. *)
module Nn = Layer
module Optimizer = Optimizer
module Safetensors = Safetensors
module Scalar = Scalar
module Serialize = Serialize
module Tensor = Tensor
Expand Down
16 changes: 16 additions & 0 deletions src/wrapper/kind.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ let bool = Bool

type packed = T : _ t -> packed

(* [element_size_in_bytes kind] is the number of bytes used to store a single
   element of the given kind. Constructors are grouped by size. *)
let element_size_in_bytes : type a. a t -> int = function
  | Uint8 | Int8 | Bool -> 1
  | Int16 | Half -> 2
  | Int | Float | ComplexHalf -> 4
  | Int64 | Double | ComplexFloat -> 8
  | ComplexDouble -> 16

(* Same as [element_size_in_bytes], for a kind wrapped in [packed]. *)
let packed_element_size_in_bytes packed =
  match packed with
  | T kind -> element_size_in_bytes kind

(* Hardcoded, should match ScalarType.h *)
let to_int : type a. a t -> int = function
| Uint8 -> 0
Expand Down
2 changes: 2 additions & 0 deletions src/wrapper/kind.mli
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ val bool : [ `bool ] t

type packed = T : _ t -> packed

(** [element_size_in_bytes kind] is the size in bytes of one element of kind
    [kind], e.g. 1 for [Uint8], 4 for [Float] and 16 for [ComplexDouble]. *)
val element_size_in_bytes : _ t -> int

(** Same as {!element_size_in_bytes}, for a packed kind. *)
val packed_element_size_in_bytes : packed -> int
val to_int : _ t -> int
val packed_to_int : packed -> int
val of_int_exn : int -> packed
Expand Down
32 changes: 23 additions & 9 deletions src/wrapper/wrapper.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,32 @@ module Tensor = struct
tensor_of_data
(bigarray_start genarray ga |> to_voidp)
(Array.to_list dims
|> List.map Int64.of_int
|> CArray.of_list int64_t
|> CArray.start)
|> List.map Int64.of_int
|> CArray.of_list int64_t
|> CArray.start)
(Array.length dims)
(Bigarray.kind_size_in_bytes kind)
(Kind.packed_to_int tensor_kind)
in
Gc.finalise free t;
t

(* [of_bigarray_bytes ga packed_kind ~shape] builds a tensor of kind
   [packed_kind] and dimensions [shape] from the raw bytes held in [ga], by
   passing a pointer to the start of [ga] to [tensor_of_data].

   [tensor_of_data] is given the pointer, the dimensions and the element size
   but not the length of the buffer, so the size is checked here first.
   NOTE(review): presumably [tensor_of_data] reads
   element_size * product(shape) bytes from the pointer - confirm against the
   C stub.

   Raises [Invalid_argument] when [shape] contains a negative dimension or
   when [ga] holds fewer bytes than [shape] and [packed_kind] require.
   The returned tensor is released with [free] when it is garbage collected. *)
let of_bigarray_bytes
  (ga : (_, _, Bigarray.c_layout) Bigarray.Genarray.t)
  packed_kind
  ~shape
  =
  let element_size = Kind.packed_element_size_in_bytes packed_kind in
  if List.exists (fun dim -> dim < 0) shape
  then invalid_arg "of_bigarray_bytes: negative dimension in shape";
  (* An empty shape denotes a scalar, which still needs one element. *)
  let required_bytes = List.fold_left ( * ) element_size shape in
  let available_bytes = Bigarray.Genarray.size_in_bytes ga in
  if required_bytes > available_bytes
  then
    invalid_arg
      (Printf.sprintf
         "of_bigarray_bytes: shape requires %d bytes but the buffer only has %d"
         required_bytes
         available_bytes);
  let t =
    tensor_of_data
      (bigarray_start genarray ga |> to_voidp)
      (List.map Int64.of_int shape |> CArray.of_list int64_t |> CArray.start)
      (List.length shape)
      element_size
      (Kind.packed_to_int packed_kind)
  in
  Gc.finalise free t;
  t

let copy_to_bigarray (type a b) t (ga : (b, a, Bigarray.c_layout) Bigarray.Genarray.t) =
let kind = Bigarray.Genarray.kind ga in
copy_data
Expand Down Expand Up @@ -191,9 +207,7 @@ module Tensor = struct

(* [sum t] calls the previously bound [sum] with [dtype] set to the kind of
   [t] itself. *)
let sum t = sum t ~dtype:(kind t)
(* [mean t] calls the previously bound [mean] with [dtype] set to the kind of
   [t] itself. *)
let mean t = mean t ~dtype:(kind t)

(* [to_raw_pointer t] returns [t] unchanged. *)
let to_raw_pointer t = t

(* [of_raw_pointer t] returns [t] unchanged. *)
let of_raw_pointer t = t
end

Expand Down Expand Up @@ -254,15 +268,15 @@ module Serialize = struct
(* [escape s] returns a copy of [s] in which every '.' character is replaced
   by '|'; all other characters are kept.
   NOTE(review): the match arms below appear twice; this looks like the removed
   and added sides of a re-indentation diff - confirm. *)
let escape s =
String.map
(function
| '.' -> '|'
| c -> c)
| '.' -> '|'
| c -> c)
s

(* [unescape s] returns a copy of [s] in which every '|' character is replaced
   by '.'; all other characters are kept. It maps characters in the opposite
   direction to [escape].
   NOTE(review): the match arms below appear twice; this looks like the removed
   and added sides of a re-indentation diff - confirm. *)
let unescape s =
String.map
(function
| '|' -> '.'
| c -> c)
| '|' -> '.'
| c -> c)
s

let load ~filename =
Expand Down
8 changes: 8 additions & 0 deletions src/wrapper/wrapper.mli
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@ end

module Tensor : sig
type t

include Wrapper_generated_intf.S with type t := t and type 'a scalar := 'a Scalar.t

val new_tensor : unit -> t
val float_vec : ?kind:[ `double | `float | `half ] -> float list -> t
val int_vec : ?kind:[ `int | `int16 | `int64 | `int8 | `uint8 ] -> int list -> t
val of_bigarray : (_, _, Bigarray.c_layout) Bigarray.Genarray.t -> t

(** [of_bigarray_bytes ga kind ~shape] creates a tensor with element kind
    [kind] and dimensions [shape] from the raw bytes stored in the unsigned
    8-bit bigarray [ga].

    NOTE(review): the implementation passes a pointer to the start of [ga]
    without its length, so [ga] is presumably expected to hold at least
    element_size * product(shape) bytes - confirm. Whether the bytes are
    copied or aliased is decided by the underlying C stub, which is not
    visible here. *)
val of_bigarray_bytes
: (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Genarray.t
-> Kind.packed
-> shape:int list
-> t

val copy_to_bigarray : t -> (_, _, Bigarray.c_layout) Bigarray.Genarray.t -> unit
val shape : t -> int list
val size : t -> int list
Expand Down

0 comments on commit d860cb9

Please sign in to comment.