Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Numpy 2.x #429

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;
use crate::array::get_array_module;
use crate::cold;
use crate::error::BorrowError;
use crate::npyffi::{PyArrayObject, PyArray_Check, NPY_ARRAY_WRITEABLE};
use crate::npyffi::{PyArrayObject, PyArray_Check, PyDataType_ELSIZE, NPY_ARRAY_WRITEABLE};

/// Defines the shared C API used for borrow checking
///
Expand Down Expand Up @@ -48,7 +48,7 @@ unsafe extern "C" fn acquire_shared(flags: *mut c_void, array: *mut PyArrayObjec
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

match flags.acquire(address, key) {
Ok(()) => 0,
Expand All @@ -66,7 +66,7 @@ unsafe extern "C" fn acquire_mut_shared(flags: *mut c_void, array: *mut PyArrayO
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

match flags.acquire_mut(address, key) {
Ok(()) => 0,
Expand All @@ -80,7 +80,7 @@ unsafe extern "C" fn release_shared(flags: *mut c_void, array: *mut PyArrayObjec
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

flags.release(address, key);
}
Expand All @@ -91,7 +91,7 @@ unsafe extern "C" fn release_mut_shared(flags: *mut c_void, array: *mut PyArrayO
let flags = &mut *(flags as *mut BorrowFlags);

let address = base_address(py, array);
let key = borrow_key(array);
let key = borrow_key(py, array);

flags.release_mut(address, key);
}
Expand Down Expand Up @@ -379,8 +379,8 @@ fn base_address<'py>(py: Python<'py>, mut array: *mut PyArrayObject) -> *mut c_v
}
}

fn borrow_key(array: *mut PyArrayObject) -> BorrowKey {
let range = data_range(array);
fn borrow_key<'py>(py: Python<'py>, array: *mut PyArrayObject) -> BorrowKey {
let range = data_range(py, array);

let data_ptr = unsafe { (*array).data };
let gcd_strides = gcd_strides(array);
Expand All @@ -392,7 +392,7 @@ fn borrow_key(array: *mut PyArrayObject) -> BorrowKey {
}
}

fn data_range(array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
fn data_range<'py>(py: Python<'py>, array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
let nd = unsafe { (*array).nd } as usize;
let data = unsafe { (*array).data };

Expand All @@ -403,7 +403,7 @@ fn data_range(array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
let shape = unsafe { from_raw_parts((*array).dimensions as *mut usize, nd) };
let strides = unsafe { from_raw_parts((*array).strides, nd) };

let itemsize = unsafe { (*(*array).descr).elsize } as isize;
let itemsize = unsafe { PyDataType_ELSIZE(py, (*array).descr) } as isize;

let mut start = 0;
let mut end = 0;
Expand Down Expand Up @@ -468,7 +468,7 @@ mod tests {
let base_address = base_address(py, array.as_array_ptr());
assert_eq!(base_address, array.as_ptr().cast());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(6) } as *mut c_char);
});
Expand All @@ -486,7 +486,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(6).cast::<c_char>()
Expand Down Expand Up @@ -517,7 +517,7 @@ mod tests {
assert_ne!(base_address, view.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(4) } as *mut c_char);
});
Expand Down Expand Up @@ -550,7 +550,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(4).cast::<c_char>()
Expand Down Expand Up @@ -600,7 +600,7 @@ mod tests {
assert_ne!(base_address, view1.as_ptr().cast::<c_void>());
assert_eq!(base_address, base as *mut c_void);

let data_range = data_range(view2.as_array_ptr());
let data_range = data_range(py, view2.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, unsafe { array.data().add(1) } as *mut c_char);
});
Expand Down Expand Up @@ -652,7 +652,7 @@ mod tests {
assert_ne!(base_address, array.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view2.as_array_ptr());
let data_range = data_range(py, view2.as_array_ptr());
assert_eq!(data_range.0, array.data().cast::<c_char>());
assert_eq!(data_range.1, unsafe {
array.data().add(1).cast::<c_char>()
Expand Down Expand Up @@ -683,7 +683,7 @@ mod tests {
assert_ne!(base_address, view.as_ptr().cast::<c_void>());
assert_eq!(base_address, base.cast::<c_void>());

let data_range = data_range(view.as_array_ptr());
let data_range = data_range(py, view.as_array_ptr());
assert_eq!(view.data(), unsafe { array.data().offset(2) });
assert_eq!(data_range.0, unsafe { view.data().offset(-2) }
as *mut c_char);
Expand All @@ -703,7 +703,7 @@ mod tests {
let base_address = base_address(py, array.as_array_ptr());
assert_eq!(base_address, array.as_ptr().cast::<c_void>());

let data_range = data_range(array.as_array_ptr());
let data_range = data_range(py, array.as_array_ptr());
assert_eq!(data_range.0, array.data() as *mut c_char);
assert_eq!(data_range.1, array.data() as *mut c_char);
});
Expand All @@ -721,7 +721,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key1 = borrow_key(view1.as_array_ptr());
let key1 = borrow_key(py, view1.as_array_ptr());

assert_eq!(view1.strides(), &[80, 24]);
assert_eq!(key1.gcd_strides, 8);
Expand All @@ -732,7 +732,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key2 = borrow_key(view2.as_array_ptr());
let key2 = borrow_key(py, view2.as_array_ptr());

assert_eq!(view2.strides(), &[80, 24]);
assert_eq!(key2.gcd_strides, 8);
Expand All @@ -743,7 +743,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key3 = borrow_key(view3.as_array_ptr());
let key3 = borrow_key(py, view3.as_array_ptr());

assert_eq!(view3.strides(), &[80, 16]);
assert_eq!(key3.gcd_strides, 16);
Expand All @@ -754,7 +754,7 @@ mod tests {
.downcast_into::<PyArray2<f64>>()
.unwrap();

let key4 = borrow_key(view4.as_array_ptr());
let key4 = borrow_key(py, view4.as_array_ptr());

assert_eq!(view4.strides(), &[80, 16]);
assert_eq!(key4.gcd_strides, 16);
Expand All @@ -777,7 +777,7 @@ mod tests {
let base1 = base_address(py, array1.as_array_ptr());
let base2 = base_address(py, array2.as_array_ptr());

let key1 = borrow_key(array1.as_array_ptr());
let key1 = borrow_key(py, array1.as_array_ptr());
let _exclusive1 = array1.readwrite();

{
Expand All @@ -791,7 +791,7 @@ mod tests {
assert_eq!(flag, -1);
}

let key2 = borrow_key(array2.as_array_ptr());
let key2 = borrow_key(py, array2.as_array_ptr());
let _shared2 = array2.readonly();

{
Expand Down Expand Up @@ -827,7 +827,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key1 = borrow_key(view1.as_array_ptr());
let key1 = borrow_key(py, view1.as_array_ptr());
let exclusive1 = view1.readwrite();

{
Expand All @@ -847,7 +847,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key2 = borrow_key(view2.as_array_ptr());
let key2 = borrow_key(py, view2.as_array_ptr());
let shared2 = view2.readonly();

{
Expand All @@ -870,7 +870,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key3 = borrow_key(view3.as_array_ptr());
let key3 = borrow_key(py, view3.as_array_ptr());
let shared3 = view3.readonly();

{
Expand All @@ -896,7 +896,7 @@ mod tests {
.downcast_into::<PyArray1<f64>>()
.unwrap();

let key4 = borrow_key(view4.as_array_ptr());
let key4 = borrow_key(py, view4.as_array_ptr());
let shared4 = view4.readonly();

{
Expand Down
6 changes: 4 additions & 2 deletions src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ use pyo3::{sync::GILProtected, Bound, Py, Python};
use rustc_hash::FxHashMap;

use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods};
use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES};
use crate::npyffi::{
PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
};

/// Represents the [datetime units][datetime-units] supported by NumPy
///
Expand Down Expand Up @@ -230,7 +232,7 @@ impl TypeDescriptors {

// SAFETY: `self.npy_type` is either `NPY_DATETIME` or `NPY_TIMEDELTA` which implies the type of `c_metadata`.
unsafe {
let metadata = &mut *((*dtype.as_dtype_ptr()).c_metadata
let metadata = &mut *(PyDataType_C_METADATA(py, dtype.as_dtype_ptr())
as *mut PyArray_DatetimeDTypeMetaData);

metadata.meta.base = unit;
Expand Down
Loading