From c63b0fc3bfbd79faf03dedd91a1c4b97a7474ff2 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 9 Jan 2024 15:45:46 +0000 Subject: [PATCH] feat: Const::from_bool function --- src/algorithm/const_fold.rs | 6 +++--- src/ops/constant.rs | 10 ++++++++++ .../arithmetic/float_ops/const_fold.rs | 6 +----- src/std_extensions/logic.rs | 12 ++---------- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs index 16e4fe573..e8b8f24a2 100644 --- a/src/algorithm/const_fold.rs +++ b/src/algorithm/const_fold.rs @@ -225,7 +225,7 @@ mod test { use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES}; - use crate::std_extensions::logic::{self, const_from_bool, NaryLogic}; + use crate::std_extensions::logic::{self, NaryLogic}; use rstest::rstest; /// int to constant @@ -320,7 +320,7 @@ mod test { ) -> Result<(), Box> { let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let ins = ins.map(|b| build.add_load_const(const_from_bool(b)).unwrap()); + let ins = ins.map(|b| build.add_load_const(Const::from_bool(b)).unwrap()); let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?; let reg = @@ -328,7 +328,7 @@ mod test { let mut h = build.finish_hugr_with_outputs(logic_op.outputs(), ®)?; constant_fold_pass(&mut h, ®); - assert_fully_folded(&h, &const_from_bool(out)); + assert_fully_folded(&h, &Const::from_bool(out)); Ok(()) } diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 64954a655..900384468 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -64,6 +64,16 @@ impl Const { Self::unit_sum(1, 2) } + /// Generate a constant equivalent of a boolean, + /// see [`Const::true_val`] and [`Const::false_val`]. + pub fn from_bool(b: bool) -> Self { + if b { + Self::true_val() + } else { + Self::false_val() + } + } + /// Constant "false" value, i.e. the first variant of Sum((), ()). pub fn false_val() -> Self { Self::unit_sum(0, 2) diff --git a/src/std_extensions/arithmetic/float_ops/const_fold.rs b/src/std_extensions/arithmetic/float_ops/const_fold.rs index 34d162f4d..130908d74 100644 --- a/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -86,11 +86,7 @@ impl ConstFold for CmpFold { ) -> ConstFoldResult { let [f1, f2] = get_floats(consts)?; - let res = if (self.0)(f1, f2) { - ops::Const::true_val() - } else { - ops::Const::false_val() - }; + let res = ops::Const::from_bool((self.0)(f1, f2)); Some(vec![(0.into(), res)]) } diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index d3abc9d4c..bf0d4c86b 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -53,12 +53,12 @@ impl MakeOpDef for NaryLogic { NaryLogic::And => |consts: &_| { let inps = read_inputs(consts)?; let res = inps.into_iter().all(|x| x); - Some(vec![(0.into(), const_from_bool(res))]) + Some(vec![(0.into(), ops::Const::from_bool(res))]) }, NaryLogic::Or => |consts: &_| { let inps = read_inputs(consts)?; let res = inps.into_iter().any(|x| x); - Some(vec![(0.into(), const_from_bool(res))]) + Some(vec![(0.into(), ops::Const::from_bool(res))]) }, }) } @@ -206,14 +206,6 @@ fn read_inputs(consts: &[(IncomingPort, ops::Const)]) -> Option> { Some(inps) } -pub(crate) fn const_from_bool(res: bool) -> ops::Const { - if res { - ops::Const::true_val() - } else { - ops::Const::false_val() - } -} - #[cfg(test)] pub(crate) mod test { use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME};