@@ -7,6 +7,8 @@ use arrow2_convert::{serialize::ArrowSerialize, ArrowDeserialize, ArrowField, Ar
7
7
use crate :: msg_bundle:: Component ;
8
8
use crate :: { TensorDataType , TensorElement } ;
9
9
10
+ use super :: arrow_convert_shims:: BinaryBuffer ;
11
+
10
12
pub trait TensorTrait {
11
13
fn id ( & self ) -> TensorId ;
12
14
fn shape ( & self ) -> & [ TensorDimension ] ;
@@ -16,6 +18,7 @@ pub trait TensorTrait {
16
18
fn meaning ( & self ) -> TensorDataMeaning ;
17
19
fn get ( & self , index : & [ u64 ] ) -> Option < TensorElement > ;
18
20
fn dtype ( & self ) -> TensorDataType ;
21
+ fn size_in_bytes ( & self ) -> usize ;
19
22
}
20
23
21
24
// ----------------------------------------------------------------------------
@@ -154,7 +157,7 @@ impl ArrowDeserialize for TensorId {
154
157
#[ derive( Clone , Debug , PartialEq , ArrowField , ArrowSerialize , ArrowDeserialize ) ]
155
158
#[ arrow_field( type = "dense" ) ]
156
159
pub enum TensorData {
157
- U8 ( Vec < u8 > ) ,
160
+ U8 ( BinaryBuffer ) ,
158
161
U16 ( Buffer < u16 > ) ,
159
162
U32 ( Buffer < u32 > ) ,
160
163
U64 ( Buffer < u64 > ) ,
@@ -168,7 +171,7 @@ pub enum TensorData {
168
171
//F16(Vec<arrow2::types::f16>),
169
172
F32 ( Buffer < f32 > ) ,
170
173
F64 ( Buffer < f64 > ) ,
171
- JPEG ( Vec < u8 > ) ,
174
+ JPEG ( BinaryBuffer ) ,
172
175
}
173
176
174
177
/// Flattened `Tensor` data payload
@@ -404,6 +407,21 @@ impl TensorTrait for Tensor {
404
407
TensorData :: F64 ( _) => TensorDataType :: F64 ,
405
408
}
406
409
}
410
+
411
+ fn size_in_bytes ( & self ) -> usize {
412
+ match & self . data {
413
+ TensorData :: U8 ( buf) | TensorData :: JPEG ( buf) => buf. 0 . len ( ) ,
414
+ TensorData :: U16 ( buf) => buf. len ( ) ,
415
+ TensorData :: U32 ( buf) => buf. len ( ) ,
416
+ TensorData :: U64 ( buf) => buf. len ( ) ,
417
+ TensorData :: I8 ( buf) => buf. len ( ) ,
418
+ TensorData :: I16 ( buf) => buf. len ( ) ,
419
+ TensorData :: I32 ( buf) => buf. len ( ) ,
420
+ TensorData :: I64 ( buf) => buf. len ( ) ,
421
+ TensorData :: F32 ( buf) => buf. len ( ) ,
422
+ TensorData :: F64 ( buf) => buf. len ( ) ,
423
+ }
424
+ }
407
425
}
408
426
409
427
impl Component for Tensor {
@@ -535,7 +553,7 @@ impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
535
553
536
554
#[ cfg( feature = "image" ) ]
537
555
#[ derive( thiserror:: Error , Debug ) ]
538
- pub enum ImageError {
556
+ pub enum TensorImageError {
539
557
#[ error( transparent) ]
540
558
Image ( #[ from] image:: ImageError ) ,
541
559
@@ -575,20 +593,22 @@ impl Tensor {
575
593
#[ cfg( not( target_arch = "wasm32" ) ) ]
576
594
pub fn tensor_from_jpeg_file (
577
595
image_path : impl AsRef < std:: path:: Path > ,
578
- ) -> Result < Self , ImageError > {
596
+ ) -> Result < Self , TensorImageError > {
579
597
let jpeg_bytes = std:: fs:: read ( image_path) ?;
580
598
Self :: tensor_from_jpeg_bytes ( jpeg_bytes)
581
599
}
582
600
583
601
/// Construct a tensor from the contents of a JPEG file.
584
602
///
585
603
/// Requires the `image` feature.
586
- pub fn tensor_from_jpeg_bytes ( jpeg_bytes : Vec < u8 > ) -> Result < Self , ImageError > {
604
+ pub fn tensor_from_jpeg_bytes ( jpeg_bytes : Vec < u8 > ) -> Result < Self , TensorImageError > {
587
605
use image:: ImageDecoder as _;
588
606
let jpeg = image:: codecs:: jpeg:: JpegDecoder :: new ( std:: io:: Cursor :: new ( & jpeg_bytes) ) ?;
589
607
if jpeg. color_type ( ) != image:: ColorType :: Rgb8 {
590
608
// TODO(emilk): support gray-scale jpeg as well
591
- return Err ( ImageError :: UnsupportedJpegColorType ( jpeg. color_type ( ) ) ) ;
609
+ return Err ( TensorImageError :: UnsupportedJpegColorType (
610
+ jpeg. color_type ( ) ,
611
+ ) ) ;
592
612
}
593
613
let ( w, h) = jpeg. dimensions ( ) ;
594
614
@@ -599,7 +619,7 @@ impl Tensor {
599
619
TensorDimension :: width( w as _) ,
600
620
TensorDimension :: depth( 3 ) ,
601
621
] ,
602
- data : TensorData :: JPEG ( jpeg_bytes) ,
622
+ data : TensorData :: JPEG ( jpeg_bytes. into ( ) ) ,
603
623
meaning : TensorDataMeaning :: Unknown ,
604
624
meter : None ,
605
625
} )
@@ -608,20 +628,20 @@ impl Tensor {
608
628
/// Construct a tensor from something that can be turned into a [`image::DynamicImage`].
609
629
///
610
630
/// Requires the `image` feature.
611
- pub fn from_image ( image : impl Into < image:: DynamicImage > ) -> Result < Self , ImageError > {
631
+ pub fn from_image ( image : impl Into < image:: DynamicImage > ) -> Result < Self , TensorImageError > {
612
632
Self :: from_dynamic_image ( image. into ( ) )
613
633
}
614
634
615
635
/// Construct a tensor from [`image::DynamicImage`].
616
636
///
617
637
/// Requires the `image` feature.
618
- pub fn from_dynamic_image ( image : image:: DynamicImage ) -> Result < Self , ImageError > {
638
+ pub fn from_dynamic_image ( image : image:: DynamicImage ) -> Result < Self , TensorImageError > {
619
639
let ( w, h) = ( image. width ( ) , image. height ( ) ) ;
620
640
621
641
let ( depth, data) = match image {
622
- image:: DynamicImage :: ImageLuma8 ( image) => ( 1 , TensorData :: U8 ( image. into_raw ( ) ) ) ,
623
- image:: DynamicImage :: ImageRgb8 ( image) => ( 3 , TensorData :: U8 ( image. into_raw ( ) ) ) ,
624
- image:: DynamicImage :: ImageRgba8 ( image) => ( 4 , TensorData :: U8 ( image. into_raw ( ) ) ) ,
642
+ image:: DynamicImage :: ImageLuma8 ( image) => ( 1 , TensorData :: U8 ( image. into_raw ( ) . into ( ) ) ) ,
643
+ image:: DynamicImage :: ImageRgb8 ( image) => ( 3 , TensorData :: U8 ( image. into_raw ( ) . into ( ) ) ) ,
644
+ image:: DynamicImage :: ImageRgba8 ( image) => ( 4 , TensorData :: U8 ( image. into_raw ( ) . into ( ) ) ) ,
625
645
image:: DynamicImage :: ImageLuma16 ( image) => {
626
646
( 1 , TensorData :: U16 ( image. into_raw ( ) . into ( ) ) )
627
647
}
@@ -649,7 +669,7 @@ impl Tensor {
649
669
}
650
670
_ => {
651
671
// It is very annoying that DynamicImage is #[non_exhaustive]
652
- return Err ( ImageError :: UnsupportedImageColorType ( image. color ( ) ) ) ;
672
+ return Err ( TensorImageError :: UnsupportedImageColorType ( image. color ( ) ) ) ;
653
673
}
654
674
} ;
655
675
@@ -741,7 +761,7 @@ fn test_concat_and_slice() {
741
761
size: 4 ,
742
762
name: None ,
743
763
} ] ,
744
- data: TensorData :: JPEG ( vec![ 1 , 2 , 3 , 4 ] ) ,
764
+ data: TensorData :: JPEG ( vec![ 1 , 2 , 3 , 4 ] . into ( ) ) ,
745
765
meaning: TensorDataMeaning :: Unknown ,
746
766
meter: Some ( 1000.0 ) ,
747
767
} ] ;
@@ -752,7 +772,7 @@ fn test_concat_and_slice() {
752
772
size: 4 ,
753
773
name: None ,
754
774
} ] ,
755
- data: TensorData :: JPEG ( vec![ 5 , 6 , 7 , 8 ] ) ,
775
+ data: TensorData :: JPEG ( vec![ 5 , 6 , 7 , 8 ] . into ( ) ) ,
756
776
meaning: TensorDataMeaning :: Unknown ,
757
777
meter: None ,
758
778
} ] ;
0 commit comments