-
Notifications
You must be signed in to change notification settings - Fork 428
/
Copy pathdimension_mapping.rs
49 lines (41 loc) · 1.28 KB
/
dimension_mapping.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/// Assigns tensor dimension indices to the `width`, `height` and `channel` roles;
/// the remaining dimensions are listed in `selectors`.
///
/// Every field holding a dimension index refers to a position in the tensor's shape
/// (see `DimensionMapping::is_valid`, which checks them against the number of dimensions).
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct DimensionMapping {
    /// Indices of the dimensions that have selectors.
    ///
    /// `DimensionMapping::create` puts every dimension from index 2 upwards here.
    pub selectors: Vec<usize>,

    /// Index of the dimension used for the width, or `None` if no dimension is assigned.
    pub width: Option<usize>,

    /// Index of the dimension used for the height, or `None` if no dimension is assigned.
    pub height: Option<usize>,

    /// Flip the width.
    pub invert_width: bool,

    /// Flip the height.
    pub invert_height: bool,

    /// Index of the dimension used as the channel, or `None` if no dimension is assigned.
    pub channel: Option<usize>,
}
impl DimensionMapping {
    /// Creates the default mapping for a tensor with `num_dim` dimensions.
    ///
    /// Dimension 1 is assigned to `width` and dimension 0 to `height`; every remaining
    /// dimension (`2..num_dim`) becomes a selector. No `channel` is assigned and neither
    /// axis is inverted.
    ///
    /// NOTE(review): for `num_dim < 2` the returned mapping still refers to dimensions
    /// 0 and 1, so it does not pass `is_valid(num_dim)`. Confirm that callers handle
    /// 0-D / 1-D tensors separately.
    pub fn create(num_dim: usize) -> DimensionMapping {
        // TODO(emilk): add a heuristic here for the default
        DimensionMapping {
            width: Some(1),
            height: Some(0),
            channel: None,
            selectors: (2..num_dim).collect(),
            invert_width: false,
            invert_height: false,
        }
    }

    /// Protect against old serialized data that is not up-to-date with the new tensor.
    ///
    /// Returns `true` if every dimension index referenced by this mapping is `< num_dim`.
    /// This covers `width`, `height` and `channel` (when set) as well as every entry of
    /// `selectors`.
    pub fn is_valid(&self, num_dim: usize) -> bool {
        // Selectors must be checked too: a stale selector pointing past the last
        // dimension is just as out-of-date as a stale width/height/channel.
        let selectors_ok = self.selectors.iter().all(|&dim| dim < num_dim);

        // `None` means "not assigned", which is always valid; `flatten` skips those.
        let axes_ok = [self.width, self.height, self.channel]
            .iter()
            .flatten()
            .all(|&dim| dim < num_dim);

        selectors_ok && axes_ok
    }
}