Skip to content

Commit

Permalink
Merge pull request #803 from LaurentMazare/v2.1
Browse files Browse the repository at this point in the history
PyTorch v2.1 support
  • Loading branch information
LaurentMazare authored Oct 4, 2023
2 parents adf6dfa + 5480d6f commit dca12b6
Show file tree
Hide file tree
Showing 14 changed files with 204,772 additions and 1,246 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Changed

## v0.14.0
### Changed
- PyTorch v2.1 support
[803](https://github.com/LaurentMazare/tch-rs/pull/803).
- Add a `pyo3-tch` crate for interacting with Python via PyO3
[730](https://github.com/LaurentMazare/tch-rs/pull/730).
- Expose the cuda fuser enabled flag,
Expand Down
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tch"
version = "0.13.0"
version = "0.14.0"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand All @@ -22,7 +22,7 @@ libc = "0.2.0"
ndarray = "0.15"
rand = "0.8"
thiserror = "1"
torch-sys = { version = "0.13.0", path = "torch-sys" }
torch-sys = { version = "0.14.0", path = "torch-sys" }
zip = "0.6"
half = "2"
safetensors = "0.3.0"
Expand All @@ -35,7 +35,7 @@ serde_json = { version = "1.0.96", optional = true }
memmap2 = { version = "0.6.1", optional = true }

[dev-dependencies]
anyhow = "1"
anyhow = "^1.0.60"

[workspace]
members = [
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The code generation part for the C api on top of libtorch comes from

## Getting Started

This crate requires the C++ PyTorch library (libtorch) in version *v2.0.0* to be available on
This crate requires the C++ PyTorch library (libtorch) in version *v2.1.0* to be available on
your system. You can either:

- Use the system-wide libtorch installation (default).
Expand Down Expand Up @@ -85,7 +85,7 @@ seem to include `libtorch.a` by default so this would have to be compiled
manually, e.g. via the following:

```bash
git clone -b v2.0.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
git clone -b v2.1.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
Expand Down
8 changes: 4 additions & 4 deletions examples/python-extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tch-ext"
version = "0.1.0"
version = "0.2.0"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand All @@ -18,6 +18,6 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.13.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.13.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.13.0" }
pyo3-tch = { path = "../../pyo3-tch", version = "0.14.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.14.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.14.0" }
3 changes: 2 additions & 1 deletion gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ let excluded_prefixes =
; "_amp_foreach"
; "_nested_tensor"
; "_fused_adam"
; "sym_"
]

let excluded_suffixes = [ "_forward"; "_forward_out" ]
Expand Down Expand Up @@ -877,7 +878,7 @@ let run

let () =
run
~yaml_filename:"third_party/pytorch/Declarations-v2.0.0.yaml"
~yaml_filename:"third_party/pytorch/Declarations-v2.1.0.yaml"
~cpp_filename:"torch-sys/libtch/torch_api_generated"
~ffi_filename:"torch-sys/src/c_generated.rs"
~wrapper_filename:"src/wrappers/tensor_generated.rs"
Expand Down
6 changes: 3 additions & 3 deletions pyo3-tch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pyo3-tch"
version = "0.13.0"
version = "0.14.0"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand All @@ -12,6 +12,6 @@ categories = ["science"]
license = "MIT/Apache-2.0"

[dependencies]
tch = { path = "..", features = ["python-extension"], version = "0.13.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.13.0" }
tch = { path = "..", features = ["python-extension"], version = "0.14.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.14.0" }
pyo3 = { version = "0.18.3", features = ["extension-module"] }
Loading

0 comments on commit dca12b6

Please sign in to comment.