Skip to content

Commit

Permalink
feat: Add Type::as_sum and SumType::variants. (#1914)
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q authored Feb 10, 2025
1 parent 5e7c81e commit 0124f7e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
1 change: 1 addition & 0 deletions hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl std::ops::Deref for ConstF64 {

impl ConstF64 {
/// Name of the constructor for creating constant 64bit floats.
#[cfg_attr(not(feature = "model_unstable"), allow(dead_code))]
pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const-f64";

/// Create a new [`ConstF64`]
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub struct ConstInt {

impl ConstInt {
/// Name of the constructor for creating constant integers.
#[cfg_attr(not(feature = "model_unstable"), allow(dead_code))]
pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const";

/// Create a new [`ConstInt`] with a given width and unsigned value
Expand Down
1 change: 1 addition & 0 deletions hugr-core/src/std_extensions/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub struct ArrayValue {

impl ArrayValue {
/// Name of the constructor for creating constant arrays.
#[cfg_attr(not(feature = "model_unstable"), allow(dead_code))]
pub(crate) const CTR_NAME: &'static str = "collections.array.const";

/// Create a new [CustomConst] for an array of values of type `typ`.
Expand Down
57 changes: 50 additions & 7 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use type_row::{TypeRow, TypeRowRV};
pub(crate) use poly_func::PolyFuncTypeBase;

use itertools::FoldWhile::{Continue, Done};
use itertools::{repeat_n, Itertools};
use itertools::{Either, Itertools as _};
#[cfg(test)]
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -189,7 +189,7 @@ impl std::fmt::Display for SumType {
SumType::Unit { size: 1 } => write!(f, "Unit"),
SumType::Unit { size: 2 } => write!(f, "Bool"),
SumType::Unit { size } => {
display_list_with_separator(repeat_n("[]", *size as usize), f, "+")
display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+")
}
SumType::General { rows } => match rows.len() {
1 if rows[0].is_empty() => write!(f, "Unit"),
Expand All @@ -216,17 +216,17 @@ impl SumType {
}
}

/// New UnitSum with empty Tuple variants
/// New UnitSum with empty Tuple variants.
pub const fn new_unary(size: u8) -> Self {
Self::Unit { size }
}

/// New tuple (single row of variants)
/// New tuple (single row of variants).
pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
Self::new([types.into()])
}

/// New option type (either an empty option, or a row of types)
/// New option type (either an empty option, or a row of types).
pub fn new_option(types: impl Into<TypeRow>) -> Self {
Self::new([vec![].into(), types.into()])
}
Expand All @@ -248,14 +248,25 @@ impl SumType {
}
}

/// Returns variant row if there is only one variant
/// Returns variant row if there is only one variant.
pub fn as_tuple(&self) -> Option<&TypeRowRV> {
match self {
SumType::Unit { size } if *size == 1 => Some(TypeRV::EMPTY_TYPEROW_REF),
SumType::General { rows } if rows.len() == 1 => Some(&rows[0]),
_ => None,
}
}

/// Returns an iterator over the variants.
pub fn variants(&self) -> impl Iterator<Item = &TypeRowRV> {
match self {
SumType::Unit { size } => Either::Left(itertools::repeat_n(
TypeRV::EMPTY_TYPEROW_REF,
*size as usize,
)),
SumType::General { rows } => Either::Right(rows.iter()),
}
}
}

impl<RV: MaybeRV> From<SumType> for TypeBase<RV> {
Expand Down Expand Up @@ -453,6 +464,14 @@ impl<RV: MaybeRV> TypeBase<RV> {
&mut self.0
}

/// Returns the inner [SumType] if the type is a sum.
pub fn as_sum(&self) -> Option<&SumType> {
match &self.0 {
TypeEnum::Sum(s) => Some(s),
_ => None,
}
}

/// Report if the type is copyable - i.e.the least upper bound of the type
/// is contained by the copyable bound.
pub const fn copyable(&self) -> bool {
Expand Down Expand Up @@ -713,13 +732,37 @@ pub(crate) mod test {
assert_eq!(pred1, Type::from(pred_direct));
}

#[test]
fn as_sum() {
let t = Type::new_unit_sum(0);
assert!(t.as_sum().is_some());
}

#[test]
fn sum_variants() {
{
let variants: Vec<TypeRowRV> = vec![
TypeRV::UNIT.into(),
vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(),
];
let t = SumType::new(variants.clone());
assert_eq!(variants, t.variants().cloned().collect_vec());
}
{
assert_eq!(
vec![&TypeRV::EMPTY_TYPEROW; 3],
SumType::new_unary(3).variants().collect_vec()
);
}
}

mod proptest {

use crate::proptest::RecursionDepth;

use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum};
use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV};
use ::proptest::prelude::*;
use proptest::prelude::*;

impl Arbitrary for super::SumType {
type Parameters = RecursionDepth;
Expand Down

0 comments on commit 0124f7e

Please sign in to comment.