From abcd598b3bf80790db5437460bf4ab05b7f7dc72 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 30 Nov 2023 10:46:43 +0000 Subject: [PATCH] feat: IntOpType convenience struct --- src/ops/custom.rs | 17 ++- src/std_extensions/arithmetic/int_ops.rs | 155 ++++++++++++++++++----- 2 files changed, 141 insertions(+), 31 deletions(-) diff --git a/src/ops/custom.rs b/src/ops/custom.rs index f5c013d606..b179a7bb25 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -10,6 +10,7 @@ use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; use crate::{Hugr, Node}; +use super::dataflow::DataflowOpTrait; use super::tag::OpTag; use super::{LeafOp, OpTrait, OpType}; @@ -74,7 +75,7 @@ impl ExternalOp { pub fn description(&self) -> &str { match self { Self::Opaque(op) => op.description.as_str(), - Self::Extension(ExtensionOp { def, .. }) => def.description(), + Self::Extension(ext_op) => DataflowOpTrait::description(ext_op), } } @@ -86,7 +87,7 @@ impl ExternalOp { .signature .clone() .expect("Op should have been serialized with signature."), - Self::Extension(ExtensionOp { signature, .. }) => signature.clone(), + Self::Extension(ext_op) => ext_op.signature(), } } } @@ -170,6 +171,18 @@ impl PartialEq for ExtensionOp { } } +impl DataflowOpTrait for ExtensionOp { + const TAG: OpTag = OpTag::Leaf; + + fn description(&self) -> &str { + self.def().description() + } + + fn signature(&self) -> FunctionType { + self.signature.clone() + } +} + impl Eq for ExtensionOp {} /// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`] diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 008c5d9884..7fa006c00b 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -2,8 +2,12 @@ use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; -use crate::extension::simple_op::MakeOpDef; -use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs}; +use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; +use crate::extension::{ + CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE, +}; +use crate::ops::custom::ExtensionOp; +use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; @@ -14,6 +18,7 @@ use crate::{ }; use lazy_static::lazy_static; +use smol_str::SmolStr; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; /// The extension identifier. @@ -39,7 +44,7 @@ impl ValidateJustArgs for IOValidator { /// Logic extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs, non_camel_case_types)] -pub enum IntOps { +pub enum IntOpDef { iwiden_u, iwiden_s, inarrow_u, @@ -87,13 +92,13 @@ pub enum IntOps { irotr, } -impl MakeOpDef for IntOps { +impl MakeOpDef for IntOpDef { fn from_def(op_def: &OpDef) -> Result { crate::extension::simple_op::try_from_name(op_def.name()) } fn signature(&self) -> SignatureFunc { - use IntOps::*; + use IntOpDef::*; match self { iwiden_s | iwiden_u => CustomValidator::new_with_validator( int_polytype(2, vec![int_tv(0)], vec![int_tv(1)]), @@ -152,7 +157,7 @@ impl MakeOpDef for IntOps { } fn description(&self) -> String { - use IntOps::*; + use IntOpDef::*; match self { iwiden_u => "widen an unsigned integer to a wider one with the same value", @@ -241,18 +246,91 @@ lazy_static! { ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); - IntOps::load_all_ops(&mut extension).unwrap(); + IntOpDef::load_all_ops(&mut extension).unwrap(); extension }; + + /// Registry of extensions required to validate integer operations. + pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + super::int_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +/// Concrete integer operation with either one or two integer widths set. +#[derive(Debug, Clone, PartialEq)] +pub struct IntOpType { + def: IntOpDef, + first_width: u64, + second_width: Option, +} + +impl OpName for IntOpType { + fn name(&self) -> SmolStr { + self.def.name() + } +} +impl MakeExtensionOp for IntOpType { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def = IntOpDef::from_def(ext_op.def())?; + let (first_width, second_width) = match *ext_op.args() { + [TypeArg::BoundedNat { n }] => (n, None), + [TypeArg::BoundedNat { n }, TypeArg::BoundedNat { n: n2 }] => (n, Some(n2)), + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + Ok(Self { + def, + first_width, + second_width, + }) + } + + fn type_args(&self) -> Vec { + [Some(self.first_width), self.second_width] + .iter() + .flatten() + .map(|&n| TypeArg::BoundedNat { n }) + .collect() + } +} + +impl MakeRegisteredOp for IntOpType { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &INT_OPS_REGISTRY + } +} + +impl IntOpDef { + /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires one + /// integer width set. + pub fn one_width(self, width: u64) -> IntOpType { + IntOpType { + def: self, + first_width: width, + second_width: None, + } + } + /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires two + /// integer widths set. + pub fn two_widths(self, first_width: u64, second_width: u64) -> IntOpType { + IntOpType { + def: self, + first_width, + second_width: Some(second_width), + } + } } #[cfg(test)] mod test { - use crate::{ - extension::{ExtensionRegistry, PRELUDE}, - std_extensions::arithmetic::int_types::int_type, - }; + use crate::{ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type}; use super::*; @@ -271,33 +349,52 @@ mod test { } #[test] fn test_binary_signatures() { - let iwiden_s = EXTENSION.get_op("iwiden_s").unwrap(); - let reg = ExtensionRegistry::try_new([ - EXTENSION.to_owned(), - super::super::int_types::EXTENSION.to_owned(), - PRELUDE.to_owned(), - ]) - .unwrap(); assert_eq!( - iwiden_s.compute_signature(&[ta(3), ta(4)], ®).unwrap(), + // iwiden_s + // .compute_signature(&[ta(3), ta(4)], &INT_OPS_REGISTRY) + // .unwrap(), + IntOpDef::iwiden_s + .two_widths(3, 4) + .to_extension_op() + .unwrap() + .signature(), FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) ); - let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap(); - iwiden_u - .compute_signature(&[ta(4), ta(3)], ®) - .unwrap_err(); + // let iwiden_u = EXTENSION.get_op("iwiden_u").unwrap(); + // iwiden_u + // .compute_signature(&[ta(4), ta(3)], &INT_OPS_REGISTRY) + // .unwrap_err(); - let inarrow_s = EXTENSION.get_op("inarrow_s").unwrap(); + assert!(IntOpDef::iwiden_u + .two_widths(4, 3) + .to_extension_op() + .is_none()); assert_eq!( - inarrow_s.compute_signature(&[ta(2), ta(1)], ®).unwrap(), + IntOpDef::inarrow_s + .two_widths(2, 1) + .to_extension_op() + .unwrap() + .signature(), FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) ); - let inarrow_u = EXTENSION.get_op("inarrow_u").unwrap(); - inarrow_u - .compute_signature(&[ta(1), ta(2)], ®) - .unwrap_err(); + assert!(IntOpDef::inarrow_u + .two_widths(1, 2) + .to_extension_op() + .is_none()); + } + + #[test] + fn test_conversions() { + let o = IntOpDef::itobool.one_width(5); + assert!(IntOpDef::itobool + .two_widths(1, 2) + .to_extension_op() + .is_none()); + let ext_op = o.clone().to_extension_op().unwrap(); + + assert_eq!(IntOpType::from_extension_op(&ext_op).unwrap(), o); } }