diff --git a/cirq-ft/cirq_ft/algos/qrom.py b/cirq-ft/cirq_ft/algos/qrom.py index 85c3eca1976..2bef77aab05 100644 --- a/cirq-ft/cirq_ft/algos/qrom.py +++ b/cirq-ft/cirq_ft/algos/qrom.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Sequence, Tuple +from typing import Callable, Sequence, Tuple, Set import attr import cirq @@ -29,7 +29,11 @@ class QROM(unary_iteration_gate.UnaryIterationGate): """Gate to load data[l] in the target register when the selection stores an index l. In the case of multi-dimensional data[p,q,r,...] we use multiple named - selection registers [p, q, r, ...] to index and load the data. + selection registers [p, q, r, ...] to index and load the data. Here `p, q, r, ...` + correspond to registers named `selection0`, `selection1`, `selection2`, ... etc. + + When the input data elements contain consecutive entries of identical data elements to + load, the QROM also implements the "variable-spaced" QROM optimization described in Ref[2]. Args: data: List of numpy ndarrays specifying the data to load. If the length @@ -44,6 +48,15 @@ class QROM(unary_iteration_gate.UnaryIterationGate): registers. This can be deduced from the maximum element of each of the datasets. Should be of length len(data), i.e. the number of datasets. num_controls: The number of control registers. + + References: + [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] + (https://arxiv.org/abs/1805.03662). + Babbush et. al. (2018). Figure 1. + + [Compilation of Fault-Tolerant Quantum Heuristics for Combinatorial Optimization] + (https://arxiv.org/abs/2007.07391). + Babbush et. al. (2020). Figure 3. """ data: Sequence[NDArray] @@ -152,11 +165,22 @@ def decompose_zero_selection( yield cirq.inverse(multi_controlled_and) context.qubit_manager.qfree(and_ancilla + [and_target]) + def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int): + global_unique_element: Set[int] = set() + for data in self.data: + unique_element = np.unique(data[selection_index_prefix][l:r]) + if len(unique_element) > 1: + return False + global_unique_element.add(unique_element[0]) + if len(global_unique_element) > 1: + return False + return True + def nth_operation( self, context: cirq.DecompositionContext, control: cirq.Qid, **kwargs ) -> cirq.OP_TREE: selection_idx = tuple(kwargs[reg.name] for reg in self.selection_registers) - target_regs = {k: v for k, v in kwargs.items() if k in self.target_registers} + target_regs = {reg.name: kwargs[reg.name] for reg in self.target_registers} yield self._load_nth_data(selection_idx, lambda q: cirq.CNOT(control, q), **target_regs) def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: @@ -172,4 +196,5 @@ def __pow__(self, power: int): return NotImplemented # pragma: no cover def _value_equality_values_(self): - return (self.selection_registers, self.target_registers, self.control_registers) + data_tuple = tuple(tuple(d.flatten()) for d in self.data) + return (self.selection_registers, self.target_registers, self.control_registers, data_tuple) diff --git a/cirq-ft/cirq_ft/algos/qrom_test.py b/cirq-ft/cirq_ft/algos/qrom_test.py index 7cb67bb84d4..39b89963d8b 100644 --- a/cirq-ft/cirq_ft/algos/qrom_test.py +++ b/cirq-ft/cirq_ft/algos/qrom_test.py @@ -116,6 +116,80 @@ def test_t_complexity(data): assert cirq_ft.t_complexity(g.gate).t == max(0, 4 * n - 8), n +def _assert_qrom_has_diagram(qrom: cirq_ft.QROM, expected_diagram: str): + gh = cirq_ft.testing.GateHelper(qrom) + op = gh.operation + context = cirq.DecompositionContext(qubit_manager=cirq_ft.GreedyQubitManager(prefix="anc")) + circuit = cirq.Circuit(cirq.decompose_once(op, context=context)) + selection = [ + *itertools.chain.from_iterable(gh.quregs[reg.name] for reg in qrom.selection_registers) + ] + selection = [q for q in selection if q in circuit.all_qubits()] + anc = sorted(set(circuit.all_qubits()) - set(op.qubits)) + selection_and_anc = (selection[0],) + sum(zip(selection[1:], anc), ()) + qubit_order = cirq.QubitOrder.explicit(selection_and_anc, fallback=cirq.QubitOrder.DEFAULT) + cirq.testing.assert_has_diagram(circuit, expected_diagram, qubit_order=qubit_order) + + +def test_qrom_variable_spacing(): + # Tests for variable spacing optimization applied from https://arxiv.org/abs/2007.07391 + data = [1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8] # Figure 3a. + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4 + data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] # Figure 3b. + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (5 - 2) * 4 + data = [1, 2, 3, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7] # Negative test: t count is not (g-2)*4 + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data)).t == (8 - 2) * 4 + # Works as expected when multiple data arrays are to be loaded. + data = [1, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5] + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, data)).t == (5 - 2) * 4 + assert cirq_ft.t_complexity(cirq_ft.QROM.build(data, 2 * np.array(data))).t == (16 - 2) * 4 + # Works as expected when multidimensional input data is to be loaded + qrom = cirq_ft.QROM.build( + np.array( + [ + [1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2, 2, 2], + [2, 2, 2, 2, 2, 2, 2, 2], + ] + ) + ) + # Value to be loaded depends only the on the first bit of outer loop. + _assert_qrom_has_diagram( + qrom, + r''' +selection00: ───X───@───X───@─── + │ │ +target00: ──────────┼───────X─── + │ +target01: ──────────X─────────── + ''', + ) + # When inner loop range is not a power of 2, the inner segment tree cannot be skipped. + qrom = cirq_ft.QROM.build( + np.array( + [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [2, 2, 2, 2, 2, 2]], + dtype=int, + ) + ) + _assert_qrom_has_diagram( + qrom, + r''' +selection00: ───X───@─────────@───────@──────X───@─────────@───────@────── + │ │ │ │ │ │ +selection10: ───────(0)───────┼───────@──────────(0)───────┼───────@────── + │ │ │ │ │ │ +anc_1: ─────────────And───@───X───@───And†───────And───@───X───@───And†─── + │ │ │ │ +target00: ────────────────┼───────┼────────────────────X───────X────────── + │ │ +target01: ────────────────X───────X─────────────────────────────────────── + ''', + ) + # No T-gates needed if all elements to load are identical. + assert cirq_ft.t_complexity(cirq_ft.QROM.build([3, 3, 3, 3])).t == 0 + + @pytest.mark.parametrize( "data", [[np.arange(6).reshape(2, 3), 4 * np.arange(6).reshape(2, 3)], [np.arange(8).reshape(2, 2, 2)]], diff --git a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py index 4ff83080100..f37c2718b3b 100644 --- a/cirq-ft/cirq_ft/algos/unary_iteration_gate.py +++ b/cirq-ft/cirq_ft/algos/unary_iteration_gate.py @@ -13,7 +13,7 @@ # limitations under the License. import abc -from typing import Dict, Iterator, List, Sequence, Tuple +from typing import Callable, Dict, Iterator, List, Sequence, Tuple from numpy.typing import NDArray import cirq @@ -34,6 +34,7 @@ def _unary_iteration_segtree( r: int, l_iter: int, r_iter: int, + break_early: Callable[[int, int], bool], ) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: """Constructs a unary iteration circuit by iterating over nodes of an implicit Segment Tree. @@ -53,6 +54,11 @@ def _unary_iteration_segtree( r: Right index of the range represented by current node of the segment tree. l_iter: Left index of iteration range over which the segment tree should be constructed. r_iter: Right index of iteration range over which the segment tree should be constructed. + break_early: For each internal node of the segment tree, `break_early(l, r)` is called to + evaluate whether the unary iteration should terminate early and not recurse in the + subtree of the node representing range `[l, r)`. If True, the internal node is + considered equivalent to a leaf node and the method yields only one tuple + `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`. Yields: One `Tuple[cirq.OP_TREE, cirq.Qid, int]` for each leaf node in the segment tree. The i'th @@ -68,8 +74,8 @@ def _unary_iteration_segtree( if l >= r_iter or l_iter >= r: # Range corresponding to this node is completely outside of iteration range. return - if l == (r - 1): - # Reached a leaf node; yield the operations. + if l_iter <= l < r <= r_iter and (l == (r - 1) or break_early(l, r)): + # Reached a leaf node or a "special" internal node; yield the operations. yield tuple(ops), control, l ops.clear() return @@ -78,20 +84,24 @@ def _unary_iteration_segtree( if r_iter <= m: # Yield only left sub-tree. yield from _unary_iteration_segtree( - ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter + ops, control, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early ) return if l_iter >= m: # Yield only right sub-tree yield from _unary_iteration_segtree( - ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter + ops, control, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early ) return anc, sq = ancilla[sl], selection[sl] ops.append(and_gate.And((1, 0)).on(control, sq, anc)) - yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter) + yield from _unary_iteration_segtree( + ops, anc, selection, ancilla, sl + 1, l, m, l_iter, r_iter, break_early + ) ops.append(cirq.CNOT(control, anc)) - yield from _unary_iteration_segtree(ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter) + yield from _unary_iteration_segtree( + ops, anc, selection, ancilla, sl + 1, m, r, l_iter, r_iter, break_early + ) ops.append(and_gate.And(adjoint=True).on(control, sq, anc)) @@ -101,16 +111,17 @@ def _unary_iteration_zero_control( ancilla: Sequence[cirq.Qid], l_iter: int, r_iter: int, + break_early: Callable[[int, int], bool], ) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: sl, l, r = 0, 0, 2 ** len(selection) m = (l + r) >> 1 ops.append(cirq.X(selection[0])) yield from _unary_iteration_segtree( - ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter + ops, selection[0], selection[1:], ancilla, sl, l, m, l_iter, r_iter, break_early ) ops.append(cirq.X(selection[0])) yield from _unary_iteration_segtree( - ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter + ops, selection[0], selection[1:], ancilla, sl, m, r, l_iter, r_iter, break_early ) @@ -121,9 +132,12 @@ def _unary_iteration_single_control( ancilla: Sequence[cirq.Qid], l_iter: int, r_iter: int, + break_early: Callable[[int, int], bool], ) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: sl, l, r = 0, 0, 2 ** len(selection) - yield from _unary_iteration_segtree(ops, control, selection, ancilla, sl, l, r, l_iter, r_iter) + yield from _unary_iteration_segtree( + ops, control, selection, ancilla, sl, l, r, l_iter, r_iter, break_early + ) def _unary_iteration_multi_controls( @@ -133,6 +147,7 @@ def _unary_iteration_multi_controls( ancilla: Sequence[cirq.Qid], l_iter: int, r_iter: int, + break_early: Callable[[int, int], bool], ) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: num_controls = len(controls) and_ancilla = ancilla[: num_controls - 2] @@ -142,7 +157,7 @@ def _unary_iteration_multi_controls( ) ops.append(multi_controlled_and) yield from _unary_iteration_single_control( - ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter + ops, and_target, selection, ancilla[num_controls - 1 :], l_iter, r_iter, break_early ) ops.append(cirq.inverse(multi_controlled_and)) @@ -154,6 +169,7 @@ def unary_iteration( controls: Sequence[cirq.Qid], selection: Sequence[cirq.Qid], qubit_manager: cirq.QubitManager, + break_early: Callable[[int, int], bool] = lambda l, r: False, ) -> Iterator[Tuple[cirq.OP_TREE, cirq.Qid, int]]: """The method performs unary iteration on `selection` integer in `range(l_iter, r_iter)`. @@ -181,6 +197,9 @@ def unary_iteration( ... circuit.append(j_ops) >>> circuit.append(i_ops) + Note: Unary iteration circuits assume that the selection register stores integers only in the + range `[l, r)` for which the corresponding unary iteration circuit should be built. + Args: l_iter: Starting index of the iteration range. r_iter: Ending index of the iteration range. @@ -192,6 +211,11 @@ def unary_iteration( controls: Control register of qubits. selection: Selection register of qubits. qubit_manager: A `cirq.QubitManager` to allocate new qubits. + break_early: For each internal node of the segment tree, `break_early(l, r)` is called to + evaluate whether the unary iteration should terminate early and not recurse in the + subtree of the node representing range `[l, r)`. If True, the internal node is + considered equivalent to a leaf node and the method yields only one tuple + `(OP_TREE, control_qubit, l)` for all integers in the range `[l, r)`. Yields: (r_iter - l_iter) different tuples, each corresponding to an integer in range @@ -207,14 +231,16 @@ def unary_iteration( assert len(selection) > 0 ancilla = qubit_manager.qalloc(max(0, len(controls) + len(selection) - 1)) if len(controls) == 0: - yield from _unary_iteration_zero_control(flanking_ops, selection, ancilla, l_iter, r_iter) + yield from _unary_iteration_zero_control( + flanking_ops, selection, ancilla, l_iter, r_iter, break_early + ) elif len(controls) == 1: yield from _unary_iteration_single_control( - flanking_ops, controls[0], selection, ancilla, l_iter, r_iter + flanking_ops, controls[0], selection, ancilla, l_iter, r_iter, break_early ) else: yield from _unary_iteration_multi_controls( - flanking_ops, controls, selection, ancilla, l_iter, r_iter + flanking_ops, controls, selection, ancilla, l_iter, r_iter, break_early ) qubit_manager.qfree(ancilla) @@ -231,6 +257,9 @@ class UnaryIterationGate(infra.GateWithRegisters): indexed operations on a target register depending on the index value stored in a selection register. + Note: Unary iteration circuits assume that the selection register stores integers only in the + range `[l, r)` for which the corresponding unary iteration circuit should be built. + References: [Encoding Electronic Spectra in Quantum Circuits with Linear T Complexity] (https://arxiv.org/abs/1805.03662). @@ -308,10 +337,38 @@ def decompose_zero_selection( """ raise NotImplementedError("Selection register must not be empty.") + def _break_early(self, selection_index_prefix: Tuple[int, ...], l: int, r: int) -> bool: + """Derived classes should override this method to specify an early termination condition. + + For each internal node of the unary iteration segment tree, `break_early(l, r)` is called + to evaluate whether the unary iteration should not recurse in the subtree of the node + representing range `[l, r)`. If True, the internal node is considered equivalent to a leaf + node and thus, `self.nth_operation` will be called for only integer `l` in the range [l, r). + + When the `UnaryIteration` class is constructed using multiple selection registers, i.e. we + wish to perform nested coherent for-loops, a unary iteration segment tree is constructed + corresponding to each nested coherent for-loop. For every such unary iteration segment tree, + the `_break_early` condition is checked by passing the `selection_index_prefix` tuple. + + Args: + selection_index_prefix: To evaluate the early breaking condition for the i'th nested + for-loop, the `selection_index_prefix` contains `i-1` integers corresponding to + the loop variable values for the first `i-1` nested loops. + l: Beginning of range `[l, r)` for internal node of unary iteration segment tree. + r: End (exclusive) of range `[l, r)` for internal node of unary iteration segment tree. + + Returns: + True of the `len(selection_index_prefix)`'th unary iteration should terminate early for + the given parameters. + """ + return False + def decompose_from_registers( self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] ) -> cirq.OP_TREE: - if self.selection_registers.total_bits() == 0: + if self.selection_registers.total_bits() == 0 or self._break_early( + (), 0, self.selection_registers[0].iteration_length + ): return self.decompose_zero_selection(context=context, **quregs) num_loops = len(self.selection_registers) @@ -354,6 +411,7 @@ def unary_iteration_loops( return # Use recursion to write `num_loops` nested loops using unary_iteration(). ops: List[cirq.Operation] = [] + selection_index_prefix = tuple(selection_reg_name_to_val.values()) ith_for_loop = unary_iteration( l_iter=0, r_iter=self.selection_registers[nested_depth].iteration_length, @@ -361,6 +419,7 @@ def unary_iteration_loops( controls=controls, selection=[*quregs[self.selection_registers[nested_depth].name]], qubit_manager=context.qubit_manager, + break_early=lambda l, r: self._break_early(selection_index_prefix, l, r), ) for op_tree, control_qid, n in ith_for_loop: yield op_tree @@ -368,6 +427,7 @@ def unary_iteration_loops( yield from unary_iteration_loops( nested_depth + 1, selection_reg_name_to_val, (control_qid,) ) + selection_reg_name_to_val.pop(self.selection_registers[nested_depth].name) yield ops return unary_iteration_loops(0, {}, self.control_registers.merge_qubits(**quregs)) diff --git a/cirq-ft/cirq_ft/infra/qubit_manager.py b/cirq-ft/cirq_ft/infra/qubit_manager.py index 098e6231c35..59e4bbded13 100644 --- a/cirq-ft/cirq_ft/infra/qubit_manager.py +++ b/cirq-ft/cirq_ft/infra/qubit_manager.py @@ -74,9 +74,9 @@ def qalloc(self, n: int, dim: int = 2) -> List[cirq.Qid]: return ret_qubits def qfree(self, qubits: Iterable[cirq.Qid]) -> None: - qs = set(qubits) + qs = list(dict(zip(qubits, qubits)).keys()) assert self._used_qubits.issuperset(qs), "Only managed qubits currently in-use can be freed" - self._used_qubits -= qs + self._used_qubits = self._used_qubits.difference(qs) self._free_qubits.extend(qs) def qborrow(self, n: int, dim: int = 2) -> List[cirq.Qid]: diff --git a/cirq-ft/cirq_ft/infra/qubit_manager_test.py b/cirq-ft/cirq_ft/infra/qubit_manager_test.py index bbfc6e10dd2..5eda575958a 100644 --- a/cirq-ft/cirq_ft/infra/qubit_manager_test.py +++ b/cirq-ft/cirq_ft/infra/qubit_manager_test.py @@ -92,3 +92,11 @@ def make_circuit(qm: cirq.QubitManager): ancilla_1: ───X───X─── """, ) + + +def test_greedy_qubit_manager_preserves_order(): + qm = cirq_ft.GreedyQubitManager(prefix="anc") + ancillae = [cirq.q(f"anc_{i}") for i in range(100)] + assert qm.qalloc(100) == ancillae + qm.qfree(ancillae) + assert qm.qalloc(100) == ancillae