Skip to content

Commit 9d540d1

Browse files
committed
Add Expr support to QuantumCircuit.compose
This relatively straightforwardly generalises the mapping of variables that already exists in `QuantumCircuit.compose` to be able to handle arbitrary `Expr` nodes as well. We must take care not to accidentally mutate the `Expr` nodes in the input circuit in the name of efficiency.
1 parent 539557d commit 9d540d1

File tree

2 files changed

+172
-32
lines changed

2 files changed

+172
-32
lines changed

qiskit/circuit/quantumcircuit.py

+81-32
Original file line numberDiff line numberDiff line change
@@ -949,43 +949,16 @@ def compose(
949949
)
950950
edge_map.update(zip(other.clbits, dest.cbit_argument_conversion(clbits)))
951951

952-
# Cache for `map_register_to_dest`.
953-
_map_register_cache = {}
954-
955-
def map_register_to_dest(theirs):
956-
"""Map the target's registers to suitable equivalents in the destination, adding an
957-
extra one if there's no exact match."""
958-
if theirs.name in _map_register_cache:
959-
return _map_register_cache[theirs.name]
960-
mapped_bits = [edge_map[bit] for bit in theirs]
961-
for ours in dest.cregs:
962-
if mapped_bits == list(ours):
963-
mapped_theirs = ours
964-
break
965-
else:
966-
mapped_theirs = ClassicalRegister(bits=mapped_bits)
967-
dest.add_register(mapped_theirs)
968-
_map_register_cache[theirs.name] = mapped_theirs
969-
return mapped_theirs
970-
952+
variable_mapper = _ComposeVariableMapper(dest, edge_map)
971953
mapped_instrs: list[CircuitInstruction] = []
972954
for instr in other.data:
973955
n_qargs: list[Qubit] = [edge_map[qarg] for qarg in instr.qubits]
974956
n_cargs: list[Clbit] = [edge_map[carg] for carg in instr.clbits]
975957
n_op = instr.operation.copy()
976-
977-
if getattr(n_op, "condition", None) is not None:
978-
target, value = n_op.condition
979-
if isinstance(target, Clbit):
980-
n_op.condition = (edge_map[target], value)
981-
else:
982-
n_op.condition = (map_register_to_dest(target), value)
983-
elif isinstance(n_op, SwitchCaseOp):
984-
if isinstance(n_op.target, Clbit):
985-
n_op.target = edge_map[n_op.target]
986-
else:
987-
n_op.target = map_register_to_dest(n_op.target)
988-
958+
if (condition := getattr(n_op, "condition", None)) is not None:
959+
n_op.condition = variable_mapper.map_condition(condition)
960+
if isinstance(n_op, SwitchCaseOp):
961+
n_op.target = variable_mapper.map_target(n_op.target)
989962
mapped_instrs.append(CircuitInstruction(n_op, n_qargs, n_cargs))
990963

991964
if front:
@@ -5252,3 +5225,79 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_):
52525225
else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'"
52535226
)
52545227
raise CircuitError(message)
5228+
5229+
5230+
class _ComposeVariableMapper(expr.ExprVisitor[expr.Expr]):
5231+
"""Stateful helper class that manages the mapping of variables in conditions and expressions to
5232+
items in the destination ``circuit``.
5233+
5234+
This mutates ``circuit`` by adding registers as required."""
5235+
5236+
__slots__ = ("circuit", "register_map", "bit_map")
5237+
5238+
def __init__(self, circuit, bit_map):
5239+
self.circuit = circuit
5240+
self.register_map = {}
5241+
self.bit_map = bit_map
5242+
5243+
def _map_register(self, theirs):
5244+
"""Map the target's registers to suitable equivalents in the destination, adding an
5245+
extra one if there's no exact match."""
5246+
if (mapped_theirs := self.register_map.get(theirs.name)) is not None:
5247+
return mapped_theirs
5248+
mapped_bits = [self.bit_map[bit] for bit in theirs]
5249+
for ours in self.circuit.cregs:
5250+
if mapped_bits == list(ours):
5251+
mapped_theirs = ours
5252+
break
5253+
else:
5254+
mapped_theirs = ClassicalRegister(bits=mapped_bits)
5255+
self.circuit.add_register(mapped_theirs)
5256+
self.register_map[theirs.name] = mapped_theirs
5257+
return mapped_theirs
5258+
5259+
def map_condition(self, condition, /):
5260+
"""Map the given ``condition`` so that it only references variables in the destination
5261+
circuit (as given to this class on initialisation)."""
5262+
if condition is None:
5263+
return None
5264+
if isinstance(condition, expr.Expr):
5265+
return self.map_expr(condition)
5266+
target, value = condition
5267+
if isinstance(target, Clbit):
5268+
return (self.bit_map[target], value)
5269+
return (self._map_register(target), value)
5270+
5271+
def map_target(self, target, /):
5272+
"""Map the runtime variables in a ``target`` of a :class:`.SwitchCaseOp` to the new circuit,
5273+
as defined in the ``circuit`` argument of the initialiser of this class."""
5274+
if isinstance(target, Clbit):
5275+
return self.bit_map[target]
5276+
if isinstance(target, ClassicalRegister):
5277+
return self._map_register(target)
5278+
return self.map_expr(target)
5279+
5280+
def map_expr(self, node: expr.Expr, /) -> expr.Expr:
5281+
"""Map the variables in an :class:`~.expr.Expr` node to the new circuit."""
5282+
return node.accept(self)
5283+
5284+
def visit_var(self, node, /):
5285+
if isinstance(node.var, Clbit):
5286+
return expr.Var(self.bit_map[node.var], node.type)
5287+
if isinstance(node.var, ClassicalRegister):
5288+
return expr.Var(self._map_register(node.var), node.type)
5289+
# Defensive against the expansion of the variable system; we don't want to silently do the
5290+
# wrong thing (which would be `return node` without mapping, right now).
5291+
raise CircuitError(f"unhandled variable in 'compose': {node}") # pragma: no cover
5292+
5293+
def visit_value(self, node, /):
5294+
return expr.Value(node.value, node.type)
5295+
5296+
def visit_unary(self, node, /):
5297+
return expr.Unary(node.op, node.operand.accept(self), node.type)
5298+
5299+
def visit_binary(self, node, /):
5300+
return expr.Binary(node.op, node.left.accept(self), node.right.accept(self), node.type)
5301+
5302+
def visit_cast(self, node, /):
5303+
return expr.Cast(node.operand.accept(self), node.type, implicit=node.implicit)

test/python/circuit/test_compose.py

+91
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
SwitchCaseOp,
3434
)
3535
from qiskit.circuit.library import HGate, RZGate, CXGate, CCXGate, TwoLocal
36+
from qiskit.circuit.classical import expr
3637
from qiskit.test import QiskitTestCase
3738

3839

@@ -789,6 +790,96 @@ def test_compose_noclbits_registerless(self):
789790
self.assertEqual(outer.clbits, inner.clbits)
790791
self.assertEqual(outer.cregs, [])
791792

793+
def test_expr_condition_is_mapped(self):
794+
"""Test that an expression in a condition involving several registers is mapped correctly to
795+
the destination circuit."""
796+
inner = QuantumCircuit(1)
797+
inner.x(0)
798+
a_src = ClassicalRegister(2, "a_src")
799+
b_src = ClassicalRegister(2, "b_src")
800+
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
801+
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)
802+
803+
test_1 = lambda: expr.lift(a_src[0])
804+
test_2 = lambda: expr.logic_not(b_src[1])
805+
test_3 = lambda: expr.logic_and(expr.bit_and(b_src, 2), expr.less(c_src, 7))
806+
source.if_test(test_1(), inner.copy(), [0], [])
807+
source.if_else(test_2(), inner.copy(), inner.copy(), [0], [])
808+
source.while_loop(test_3(), inner.copy(), [0], [])
809+
810+
a_dest = ClassicalRegister(2, "a_dest")
811+
b_dest = ClassicalRegister(2, "b_dest")
812+
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)
813+
814+
# Check that the input conditions weren't mutated.
815+
for in_condition, instruction in zip((test_1, test_2, test_3), source.data):
816+
self.assertEqual(in_condition(), instruction.operation.condition)
817+
818+
# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
819+
self.assertEqual(len(dest.cregs), 3)
820+
mapped_reg = dest.cregs[-1]
821+
822+
expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
823+
expected.if_test(expr.lift(a_dest[0]), inner.copy(), [0], [])
824+
expected.if_else(expr.logic_not(b_dest[1]), inner.copy(), inner.copy(), [0], [])
825+
expected.while_loop(
826+
expr.logic_and(expr.bit_and(b_dest, 2), expr.less(mapped_reg, 7)), inner.copy(), [0], []
827+
)
828+
self.assertEqual(dest, expected)
829+
830+
def test_expr_target_is_mapped(self):
831+
"""Test that an expression in a switch statement's target is mapping correctly to the
832+
destination circuit."""
833+
inner1 = QuantumCircuit(1)
834+
inner1.x(0)
835+
inner2 = QuantumCircuit(1)
836+
inner2.z(0)
837+
838+
a_src = ClassicalRegister(2, "a_src")
839+
b_src = ClassicalRegister(2, "b_src")
840+
c_src = ClassicalRegister(name="c_src", bits=list(a_src) + list(b_src))
841+
source = QuantumCircuit(QuantumRegister(1), a_src, b_src, c_src)
842+
843+
test_1 = lambda: expr.lift(a_src[0])
844+
test_2 = lambda: expr.logic_not(b_src[1])
845+
test_3 = lambda: expr.lift(b_src)
846+
test_4 = lambda: expr.bit_and(c_src, 7)
847+
source.switch(test_1(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
848+
source.switch(test_2(), [(False, inner1.copy()), (True, inner2.copy())], [0], [])
849+
source.switch(test_3(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])
850+
source.switch(test_4(), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], [])
851+
852+
a_dest = ClassicalRegister(2, "a_dest")
853+
b_dest = ClassicalRegister(2, "b_dest")
854+
dest = QuantumCircuit(QuantumRegister(1), a_dest, b_dest).compose(source)
855+
856+
# Check that the input expressions weren't mutated.
857+
for in_target, instruction in zip((test_1, test_2, test_3, test_4), source.data):
858+
self.assertEqual(in_target(), instruction.operation.target)
859+
860+
# Should be `a_dest`, `b_dest` and an added one to account for `c_src`.
861+
self.assertEqual(len(dest.cregs), 3)
862+
mapped_reg = dest.cregs[-1]
863+
864+
expected = QuantumCircuit(dest.qregs[0], a_dest, b_dest, mapped_reg)
865+
expected.switch(
866+
expr.lift(a_dest[0]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
867+
)
868+
expected.switch(
869+
expr.logic_not(b_dest[1]), [(False, inner1.copy()), (True, inner2.copy())], [0], []
870+
)
871+
expected.switch(
872+
expr.lift(b_dest), [(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())], [0], []
873+
)
874+
expected.switch(
875+
expr.bit_and(mapped_reg, 7),
876+
[(0, inner1.copy()), (CASE_DEFAULT, inner2.copy())],
877+
[0],
878+
[],
879+
)
880+
881+
self.assertEqual(dest, expected)
882+
792883

793884
if __name__ == "__main__":
794885
unittest.main()

0 commit comments

Comments
 (0)