Skip to content

Commit d07dbf5

Browse files
committed
apply check for core vs _core at runtime
1 parent 7e702be commit d07dbf5

File tree

3 files changed

+50
-5
lines changed

3 files changed

+50
-5
lines changed

src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pub type PyArrayDyn<T> = PyArray<T, IxDyn>;
122122

123123
/// Returns a handle to NumPy's multiarray module.
124124
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<Bound<'_, PyModule>> {
125-
PyModule::import_bound(py, npyffi::array::MOD_NAME)
125+
PyModule::import_bound(py, npyffi::array::mod_name(py)?)
126126
}
127127

128128
impl<T, D> DerefToPyAny for PyArray<T, D> {}

src/npyffi/array.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,43 @@ use pyo3::{
1414

1515
use crate::npyffi::*;
1616

17-
pub(crate) const MOD_NAME: &str = "numpy._core.multiarray";
17+
pub(crate) fn numpy_core_name(py: Python<'_>) -> PyResult<&'static str> {
18+
static MOD_NAME: GILOnceCell<&'static str> = GILOnceCell::new();
19+
20+
MOD_NAME
21+
.get_or_try_init(py, || {
22+
// numpy 2 renamed to numpy._core
23+
24+
// strategy mirrored from https://github.com/pybind/pybind11/blob/af67e87393b0f867ccffc2702885eea12de063fc/include/pybind11/numpy.h#L175-L195
25+
26+
let numpy = PyModule::import_bound(py, "numpy")?;
27+
let version_string = numpy.getattr("__version__")?;
28+
29+
let numpy_lib = PyModule::import_bound(py, "numpy.lib")?;
30+
let numpy_version = numpy_lib
31+
.getattr("NumpyVersion")?
32+
.call1((version_string,))?;
33+
let major_version: u8 = numpy_version.getattr("major")?.extract()?;
34+
35+
Ok(if major_version >= 2 {
36+
"numpy._core"
37+
} else {
38+
"numpy.core"
39+
})
40+
})
41+
.copied()
42+
}
43+
44+
pub(crate) fn mod_name(py: Python<'_>) -> PyResult<&'static str> {
45+
static MOD_NAME: GILOnceCell<String> = GILOnceCell::new();
46+
MOD_NAME
47+
.get_or_try_init(py, || {
48+
let numpy_core = numpy_core_name(py)?;
49+
Ok(format!("{}.multiarray", numpy_core))
50+
})
51+
.map(String::as_str)
52+
}
53+
1854
const CAPSULE_NAME: &str = "_ARRAY_API";
1955

2056
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
@@ -49,7 +85,7 @@ impl PyArrayAPI {
4985
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void {
5086
let api = self
5187
.0
52-
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
88+
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))
5389
.expect("Failed to access NumPy array API capsule");
5490

5591
api.offset(offset)

src/npyffi/ufunc.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@ use pyo3::{ffi::PyObject, sync::GILOnceCell};
66

77
use crate::npyffi::*;
88

9-
const MOD_NAME: &str = "numpy.core.umath";
9+
fn mod_name(py: Python<'_>) -> PyResult<&'static str> {
10+
static MOD_NAME: GILOnceCell<String> = GILOnceCell::new();
11+
MOD_NAME
12+
.get_or_try_init(py, || {
13+
let numpy_core = super::array::numpy_core_name(py)?;
14+
Ok(format!("{}.umath", numpy_core))
15+
})
16+
.map(String::as_str)
17+
}
18+
1019
const CAPSULE_NAME: &str = "_UFUNC_API";
1120

1221
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
@@ -23,7 +32,7 @@ impl PyUFuncAPI {
2332
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void {
2433
let api = self
2534
.0
26-
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
35+
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))
2736
.expect("Failed to access NumPy ufunc API capsule");
2837

2938
api.offset(offset)

0 commit comments

Comments
 (0)