|
14 | 14 |
|
15 | 15 | """mpl circuit visualization backend."""
|
16 | 16 |
|
| 17 | +import collections |
17 | 18 | import itertools
|
18 | 19 | import re
|
19 | 20 | from warnings import warn
|
|
33 | 34 | ForLoopOp,
|
34 | 35 | SwitchCaseOp,
|
35 | 36 | )
|
| 37 | +from qiskit.circuit.controlflow import condition_resources |
| 38 | +from qiskit.circuit.classical import expr |
36 | 39 | from qiskit.circuit.library.standard_gates import (
|
37 | 40 | SwapGate,
|
38 | 41 | RZZGate,
|
@@ -1090,45 +1093,66 @@ def _condition(self, node, node_data, wire_map, cond_xy, glob_data):
|
1090 | 1093 | # For SwitchCaseOp convert the target to a fully closed Clbit or register
|
1091 | 1094 | # in condition format
|
1092 | 1095 | if isinstance(node.op, SwitchCaseOp):
|
1093 |
| - if isinstance(node.op.target, Clbit): |
| 1096 | + if isinstance(node.op.target, expr.Expr): |
| 1097 | + condition = node.op.target |
| 1098 | + elif isinstance(node.op.target, Clbit): |
1094 | 1099 | condition = (node.op.target, 1)
|
1095 | 1100 | else:
|
1096 | 1101 | condition = (node.op.target, 2 ** (node.op.target.size) - 1)
|
1097 | 1102 | else:
|
1098 | 1103 | condition = node.op.condition
|
1099 |
| - label, val_bits = get_condition_label_val(condition, self._circuit, self._cregbundle) |
1100 |
| - cond_bit_reg = condition[0] |
1101 |
| - cond_bit_val = int(condition[1]) |
1102 | 1104 |
|
| 1105 | + override_fc = False |
1103 | 1106 | first_clbit = len(self._qubits)
|
1104 | 1107 | cond_pos = []
|
1105 | 1108 |
|
1106 |
| - # In the first case, multiple bits are indicated on the drawing. In all |
1107 |
| - # other cases, only one bit is shown. |
1108 |
| - if not self._cregbundle and isinstance(cond_bit_reg, ClassicalRegister): |
1109 |
| - for idx in range(cond_bit_reg.size): |
1110 |
| - cond_pos.append(cond_xy[wire_map[cond_bit_reg[idx]] - first_clbit]) |
1111 |
| - |
1112 |
| - # If it's a register bit and cregbundle, need to use the register to find the location |
1113 |
| - elif self._cregbundle and isinstance(cond_bit_reg, Clbit): |
1114 |
| - register = get_bit_register(self._circuit, cond_bit_reg) |
1115 |
| - if register is not None: |
1116 |
| - cond_pos.append(cond_xy[wire_map[register] - first_clbit]) |
| 1109 | + if isinstance(condition, expr.Expr): |
| 1110 | + # If fixing this, please update the docstrings of `QuantumCircuit.draw` and |
| 1111 | + # `visualization.circuit_drawer` to remove warnings. |
| 1112 | + condition_bits = condition_resources(condition).clbits |
| 1113 | + label = "[expression]" |
| 1114 | + override_fc = True |
| 1115 | + registers = collections.defaultdict(list) |
| 1116 | + for bit in condition_bits: |
| 1117 | + registers[get_bit_register(self._circuit, bit)].append(bit) |
| 1118 | + # Registerless bits don't care whether cregbundle is set. |
| 1119 | + cond_pos.extend(cond_xy[wire_map[bit] - first_clbit] for bit in registers.pop(None, ())) |
| 1120 | + if self._cregbundle: |
| 1121 | + cond_pos.extend( |
| 1122 | + cond_xy[wire_map[register[0]] - first_clbit] for register in registers |
| 1123 | + ) |
1117 | 1124 | else:
|
1118 |
| - cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) |
| 1125 | + cond_pos.extend( |
| 1126 | + cond_xy[wire_map[bit] - first_clbit] |
| 1127 | + for register, bits in registers.items() |
| 1128 | + for bit in bits |
| 1129 | + ) |
| 1130 | + val_bits = ["1"] * len(cond_pos) |
1119 | 1131 | else:
|
1120 |
| - cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) |
| 1132 | + label, val_bits = get_condition_label_val(condition, self._circuit, self._cregbundle) |
| 1133 | + cond_bit_reg = condition[0] |
| 1134 | + cond_bit_val = int(condition[1]) |
| 1135 | + override_fc = cond_bit_val != 0 |
| 1136 | + |
| 1137 | + # In the first case, multiple bits are indicated on the drawing. In all |
| 1138 | + # other cases, only one bit is shown. |
| 1139 | + if not self._cregbundle and isinstance(cond_bit_reg, ClassicalRegister): |
| 1140 | + for idx in range(cond_bit_reg.size): |
| 1141 | + cond_pos.append(cond_xy[wire_map[cond_bit_reg[idx]] - first_clbit]) |
| 1142 | + |
| 1143 | + # If it's a register bit and cregbundle, need to use the register to find the location |
| 1144 | + elif self._cregbundle and isinstance(cond_bit_reg, Clbit): |
| 1145 | + register = get_bit_register(self._circuit, cond_bit_reg) |
| 1146 | + if register is not None: |
| 1147 | + cond_pos.append(cond_xy[wire_map[register] - first_clbit]) |
| 1148 | + else: |
| 1149 | + cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) |
| 1150 | + else: |
| 1151 | + cond_pos.append(cond_xy[wire_map[cond_bit_reg] - first_clbit]) |
1121 | 1152 |
|
1122 | 1153 | xy_plot = []
|
1123 |
| - for idx, xy in enumerate(cond_pos): |
1124 |
| - if val_bits[idx] == "1" or ( |
1125 |
| - isinstance(cond_bit_reg, ClassicalRegister) |
1126 |
| - and cond_bit_val != 0 |
1127 |
| - and self._cregbundle |
1128 |
| - ): |
1129 |
| - fc = self._style["lc"] |
1130 |
| - else: |
1131 |
| - fc = self._style["bg"] |
| 1154 | + for val_bit, xy in zip(val_bits, cond_pos): |
| 1155 | + fc = self._style["lc"] if override_fc or val_bit == "1" else self._style["bg"] |
1132 | 1156 | box = glob_data["patches_mod"].Circle(
|
1133 | 1157 | xy=xy,
|
1134 | 1158 | radius=WID * 0.15,
|
|
0 commit comments