|
| 1 | +use pyo3::exceptions::PyIndexError; |
| 2 | +use pyo3::prelude::*; |
| 3 | + |
| 4 | +use genimtools::common::consts::{PAD_CHR, PAD_START, PAD_END}; |
| 5 | + |
| 6 | +use crate::models::{PyRegion, PyTokenizedRegion}; |
| 7 | + |
| 8 | +#[pyclass(name = "TokenizedRegionSet")] |
| 9 | +#[derive(Clone, Debug)] |
| 10 | +pub struct PyTokenizedRegionSet { |
| 11 | + pub regions: Vec<PyRegion>, |
| 12 | + pub ids: Vec<u32>, |
| 13 | + curr: usize, |
| 14 | +} |
| 15 | + |
| 16 | +#[pymethods] |
| 17 | +impl PyTokenizedRegionSet { |
| 18 | + #[new] |
| 19 | + pub fn new(regions: Vec<PyRegion>, ids: Vec<u32>) -> Self { |
| 20 | + PyTokenizedRegionSet { |
| 21 | + regions, |
| 22 | + ids, |
| 23 | + curr: 0, |
| 24 | + } |
| 25 | + } |
| 26 | + |
| 27 | + #[getter] |
| 28 | + pub fn regions(&self) -> PyResult<Vec<PyRegion>> { |
| 29 | + Ok(self.regions.to_owned()) |
| 30 | + } |
| 31 | + |
| 32 | + #[getter] |
| 33 | + pub fn ids(&self) -> PyResult<Vec<u32>> { |
| 34 | + Ok(self.ids.clone()) |
| 35 | + } |
| 36 | + |
| 37 | + // this is wrong: the padding token might not be in the universe |
| 38 | + pub fn pad(&mut self, len: usize) { |
| 39 | + let pad_region = PyRegion { |
| 40 | + chr: PAD_CHR.to_string(), |
| 41 | + start: PAD_START as u32, |
| 42 | + end: PAD_END as u32, |
| 43 | + }; |
| 44 | + let pad_id = self.ids[0]; |
| 45 | + let pad_region_set = PyTokenizedRegionSet { |
| 46 | + regions: vec![pad_region; len], |
| 47 | + ids: vec![pad_id; len], |
| 48 | + curr: 0, |
| 49 | + }; |
| 50 | + self.regions.extend(pad_region_set.regions); |
| 51 | + self.ids.extend(pad_region_set.ids); |
| 52 | + } |
| 53 | + |
| 54 | + pub fn __repr__(&self) -> String { |
| 55 | + format!("TokenizedRegionSet({} regions)", self.regions.len()) |
| 56 | + } |
| 57 | + |
| 58 | + pub fn __len__(&self) -> usize { |
| 59 | + self.regions.len() |
| 60 | + } |
| 61 | + |
| 62 | + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { |
| 63 | + slf |
| 64 | + } |
| 65 | + |
| 66 | + pub fn __next__(&mut self) -> Option<PyTokenizedRegion> { |
| 67 | + if self.curr < self.regions.len() { |
| 68 | + let region = self.regions[self.curr].clone(); |
| 69 | + let id = self.ids[self.curr]; |
| 70 | + |
| 71 | + self.curr += 1; |
| 72 | + Some(PyTokenizedRegion::new(region, id)) |
| 73 | + } else { |
| 74 | + None |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + pub fn __getitem__(&self, indx: isize) -> PyResult<PyTokenizedRegion> { |
| 79 | + let indx = if indx < 0 { |
| 80 | + self.regions.len() as isize + indx |
| 81 | + } else { |
| 82 | + indx |
| 83 | + }; |
| 84 | + if indx < 0 || indx >= self.regions.len() as isize { |
| 85 | + Err(PyIndexError::new_err("Index out of bounds")) |
| 86 | + } else { |
| 87 | + let region = self.regions[indx as usize].clone(); |
| 88 | + let id = self.ids[indx as usize]; |
| 89 | + Ok(PyTokenizedRegion::new(region, id)) |
| 90 | + } |
| 91 | + } |
| 92 | +} |
0 commit comments