-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathfrom_numpy.rs
59 lines (55 loc) · 2.16 KB
/
from_numpy.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
use std::sync::Arc;
use arrow::datatypes::{
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
use numpy::{
dtype_bound, PyArray1, PyArrayDescr, PyArrayDescrMethods, PyArrayMethods, PyUntypedArray,
PyUntypedArrayMethods,
};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use crate::error::PyArrowResult;
/// Convert a one-dimensional numpy array into an Arrow array.
///
/// The numpy dtype is matched, via `is_type`, against `float16`/`float32`/`float64`,
/// the unsigned and signed integer widths 8 through 64, and `bool`. The corresponding
/// Arrow `PrimitiveArray` or `BooleanArray` is built from a copy of the array's values.
///
/// # Errors
///
/// - Propagates the downcast error if `array` cannot be viewed as a
///   one-dimensional `PyArray1` of the matched element type (for example a 2-D array).
/// - Returns a `ValueError` ("Unsupported data type ...") for any other dtype.
pub fn from_numpy(py: Python, array: &Bound<PyUntypedArray>) -> PyArrowResult<ArrayRef> {
    // Copy the elements of a `PyArray1` into a `Vec` in logical order.
    // `to_vec` copies the buffer once but fails when the array is not contiguous
    // (e.g. a strided slice such as `a[::2]`). Only in that case do we gather
    // through an owned ndarray, which is what the original code did for every input.
    macro_rules! copy_values {
        ($arr:expr) => {
            match $arr.to_vec() {
                Ok(values) => values,
                Err(_) => $arr.to_owned_array().to_vec(),
            }
        };
    }

    // Downcast to the concrete element type and wrap the copied values in the
    // matching Arrow primitive array.
    macro_rules! numpy_to_arrow {
        ($rust_type:ty, $arrow_type:ty) => {{
            let arr = array.downcast::<PyArray1<$rust_type>>()?;
            Ok(Arc::new(PrimitiveArray::<$arrow_type>::from(copy_values!(arr))))
        }};
    }

    let dtype = array.dtype();
    if is_type::<half::f16>(py, &dtype) {
        numpy_to_arrow!(half::f16, Float16Type)
    } else if is_type::<f32>(py, &dtype) {
        numpy_to_arrow!(f32, Float32Type)
    } else if is_type::<f64>(py, &dtype) {
        numpy_to_arrow!(f64, Float64Type)
    } else if is_type::<u8>(py, &dtype) {
        numpy_to_arrow!(u8, UInt8Type)
    } else if is_type::<u16>(py, &dtype) {
        numpy_to_arrow!(u16, UInt16Type)
    } else if is_type::<u32>(py, &dtype) {
        numpy_to_arrow!(u32, UInt32Type)
    } else if is_type::<u64>(py, &dtype) {
        numpy_to_arrow!(u64, UInt64Type)
    } else if is_type::<i8>(py, &dtype) {
        numpy_to_arrow!(i8, Int8Type)
    } else if is_type::<i16>(py, &dtype) {
        numpy_to_arrow!(i16, Int16Type)
    } else if is_type::<i32>(py, &dtype) {
        numpy_to_arrow!(i32, Int32Type)
    } else if is_type::<i64>(py, &dtype) {
        numpy_to_arrow!(i64, Int64Type)
    } else if is_type::<bool>(py, &dtype) {
        // Booleans go to a bit-packed `BooleanArray` rather than a `PrimitiveArray`.
        let arr = array.downcast::<PyArray1<bool>>()?;
        Ok(Arc::new(BooleanArray::from(copy_values!(arr))))
    } else {
        Err(PyValueError::new_err(format!("Unsupported data type {}", dtype)).into())
    }
}
/// Returns `true` when `dtype` is numpy-equivalent to the dtype associated with the
/// Rust element type `T` (the check is delegated to `PyArrayDescr::is_equiv_to`).
fn is_type<T: numpy::Element>(py: Python, dtype: &Bound<PyArrayDescr>) -> bool {
    let expected = dtype_bound::<T>(py);
    dtype.is_equiv_to(&expected)
}