diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 6528009f6..2f042b7bd 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -55,9 +55,15 @@ impl ConstFold for LogicOp { (!res || inps.len() as u64 == 1) .then_some(vec![(0.into(), ops::Value::from_bool(res))]) } + Self::Xor => { + let inps = read_inputs(consts)?; + let res = inps.iter().fold(false, |acc, x| acc ^ *x); + (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))]) + } } } } + /// Logic extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs)] @@ -67,12 +73,13 @@ pub enum LogicOp { Or, Eq, Not, + Xor, } impl MakeOpDef for LogicOp { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { match self { - LogicOp::And | LogicOp::Or | LogicOp::Eq => { + LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => { Signature::new(vec![bool_t(); 2], vec![bool_t()]) } LogicOp::Not => Signature::new_endo(vec![bool_t()]), @@ -90,6 +97,7 @@ impl MakeOpDef for LogicOp { LogicOp::Or => "logical 'or'", LogicOp::Eq => "test if bools are equal", LogicOp::Not => "logical 'not'", + LogicOp::Xor => "logical 'xor'", } .to_string() } @@ -181,7 +189,7 @@ pub(crate) mod test { fn test_logic_extension() { let r: Arc = extension(); assert_eq!(r.name() as &str, "logic"); - assert_eq!(r.operations().count(), 4); + assert_eq!(r.operations().count(), 5); for op in LogicOp::iter() { assert_eq!( @@ -230,6 +238,8 @@ pub(crate) mod test { #[case(LogicOp::Eq, [false, false], true)] #[case(LogicOp::Not, [false], true)] #[case(LogicOp::Not, [true], false)] + #[case(LogicOp::Xor, [true, false], true)] + #[case(LogicOp::Xor, [true, true], false)] fn const_fold( #[case] op: LogicOp, #[case] ins: impl IntoIterator, @@ -256,6 +266,7 @@ pub(crate) mod test { #[case(LogicOp::Or, [None, Some(true)], Some(true))] #[case(LogicOp::Eq, [None, Some(true)], None)] #[case(LogicOp::Not, [None], None)] + #[case(LogicOp::Xor, [None, Some(true)], None)] fn partial_const_fold( #[case] op: LogicOp, #[case] ins: impl IntoIterator>, diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 32a58923c..93df299b7 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -32,30 +32,10 @@ fn emit_logic_op<'c, H: HugrView>( inputs.push(bool_val.build_get_tag(builder)?); } let res = match lot { - LogicOp::And => { - let mut acc = inputs[0]; - for inp in inputs.into_iter().skip(1) { - acc = builder.build_and(acc, inp, "")?; - } - acc - } - LogicOp::Or => { - let mut acc = inputs[0]; - for inp in inputs.into_iter().skip(1) { - acc = builder.build_or(acc, inp, "")?; - } - acc - } - LogicOp::Eq => { - let x = inputs.pop().unwrap(); - let y = inputs.pop().unwrap(); - let mut acc = builder.build_int_compare(IntPredicate::EQ, x, y, "")?; - for inp in inputs { - let eq = builder.build_int_compare(IntPredicate::EQ, inp, x, "")?; - acc = builder.build_and(acc, eq, "")?; - } - acc - } + LogicOp::And => builder.build_and(inputs[0], inputs[1], "")?, + LogicOp::Or => builder.build_or(inputs[0], inputs[1], "")?, + LogicOp::Xor => builder.build_xor(inputs[0], inputs[1], "")?, + LogicOp::Eq => builder.build_int_compare(IntPredicate::EQ, inputs[0], inputs[1], "")?, LogicOp::Not => builder.build_not(inputs[0], "")?, op => { return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}")); @@ -80,6 +60,7 @@ pub fn add_logic_extensions<'a, H: HugrView + 'a>( .extension_op(logic::EXTENSION_ID, LogicOp::And.name(), emit_logic_op) .extension_op(logic::EXTENSION_ID, LogicOp::Or.name(), emit_logic_op) .extension_op(logic::EXTENSION_ID, LogicOp::Not.name(), emit_logic_op) + .extension_op(logic::EXTENSION_ID, LogicOp::Xor.name(), emit_logic_op) // Added Xor } impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { @@ -148,4 +129,11 @@ mod test { let hugr = test_logic_op(LogicOp::Not, 1); check_emission!(hugr, llvm_ctx); } + + #[rstest] + fn xor(mut llvm_ctx: TestContext) { + llvm_ctx.add_extensions(add_logic_extensions); + let hugr = test_logic_op(LogicOp::Xor, 2); + check_emission!(hugr, llvm_ctx); + } } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@llvm14.snap index 4b3267e25..909e283f8 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@llvm14.snap @@ -10,7 +10,7 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %2 = icmp eq i1 %1, %0 + %2 = icmp eq i1 %0, %1 %3 = select i1 %2, i1 true, i1 false ret i1 %3 } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@pre-mem2reg@llvm14.snap index 354c429a2..1ae725134 100644 --- a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__eq@pre-mem2reg@llvm14.snap @@ -18,7 +18,7 @@ entry_block: ; preds = %alloca_block store i1 %1, i1* %"2_1", align 1 %"2_01" = load i1, i1* %"2_0", align 1 %"2_12" = load i1, i1* %"2_1", align 1 - %2 = icmp eq i1 %"2_12", %"2_01" + %2 = icmp eq i1 %"2_01", %"2_12" %3 = select i1 %2, i1 true, i1 false store i1 %3, i1* %"4_0", align 1 %"4_03" = load i1, i1* %"4_0", align 1 diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap new file mode 100644 index 000000000..934ef459b --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap @@ -0,0 +1,16 @@ +--- +source: hugr-llvm/src/extension/logic.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1(i1 %0, i1 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %2 = xor i1 %0, %1 + %3 = select i1 %2, i1 true, i1 false + ret i1 %3 +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..a5dcf022d --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap @@ -0,0 +1,28 @@ +--- +source: hugr-llvm/src/extension/logic.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1(i1 %0, i1 %1) { +alloca_block: + %"0" = alloca i1, align 1 + %"2_0" = alloca i1, align 1 + %"2_1" = alloca i1, align 1 + %"4_0" = alloca i1, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i1 %0, i1* %"2_0", align 1 + store i1 %1, i1* %"2_1", align 1 + %"2_01" = load i1, i1* %"2_0", align 1 + %"2_12" = load i1, i1* %"2_1", align 1 + %2 = xor i1 %"2_01", %"2_12" + %3 = select i1 %2, i1 true, i1 false + store i1 %3, i1* %"4_0", align 1 + %"4_03" = load i1, i1* %"4_0", align 1 + store i1 %"4_03", i1* %"0", align 1 + %"04" = load i1, i1* %"0", align 1 + ret i1 %"04" +} diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index 7f90392ff..ad9f02019 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -150,6 +150,37 @@ } }, "binary": false + }, + "Xor": { + "extension": "logic", + "name": "Xor", + "description": "logical 'xor'", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "runtime_reqs": [] + } + }, + "binary": false } } } diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index 7f90392ff..ad9f02019 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -150,6 +150,37 @@ } }, "binary": false + }, + "Xor": { + "extension": "logic", + "name": "Xor", + "description": "logical 'xor'", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "runtime_reqs": [] + } + }, + "binary": false } } }