diff --git a/qualtran/bloqs/data_loading/qroam_clean.py b/qualtran/bloqs/data_loading/qroam_clean.py index 4dedb5be3..2331782e5 100644 --- a/qualtran/bloqs/data_loading/qroam_clean.py +++ b/qualtran/bloqs/data_loading/qroam_clean.py @@ -14,6 +14,7 @@ import numbers from collections import defaultdict from functools import cached_property +import re from typing import cast, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union import attrs @@ -196,7 +197,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - if reg is None: return Text('QROAM').adjoint() name = reg.name - if name == 'selection': + if name.startswith('selection'): return TextBox('In').adjoint() elif 'target' in name: trg_indx = int(name.replace('target', '').replace('_', '')) @@ -283,10 +284,11 @@ def with_log_block_sizes( def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': if reg is None: return Text('QROAM').adjoint() - name = reg.name - if name == 'selection': + # Find the last instance '_' in the register name to split at. + name = reg.name + if name.startswith('selection'): return TextBox('In') - elif 'target' in name: + elif 'target' in name and 'junk' not in name: trg_indx = int(name.replace('target', '').replace('_', '')) # match the sel index subscript = chr(ord('a') + trg_indx) @@ -528,7 +530,7 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) - if reg is None: return Text('QROAM') name = reg.name - if name == 'selection': + if name.startswith('selection'): return TextBox('In') elif 'target' in name and 'junk' not in name: trg_indx = int(name.replace('target', '').replace('_', '')) diff --git a/qualtran/bloqs/data_loading/qroam_clean_test.py b/qualtran/bloqs/data_loading/qroam_clean_test.py index 1f8cd0f0a..232746065 100644 --- a/qualtran/bloqs/data_loading/qroam_clean_test.py +++ b/qualtran/bloqs/data_loading/qroam_clean_test.py @@ -15,11 +15,21 @@ import pytest import sympy +from qualtran._infra.data_types import QAny +from qualtran._infra.registers import Register +from qualtran.drawing.musical_score import ( + Text, + TextBox, + RarrowTextBox, + LarrowTextBox, + Circle +) from qualtran.bloqs.data_loading.qroam_clean import ( _qroam_clean_multi_data, _qroam_clean_multi_dim, get_optimal_log_block_size_clean_ancilla, QROAMClean, + QROAMCleanAdjoint, QROAMCleanAdjointWrapper, ) from qualtran.resource_counting import get_cost_value, QubitCount @@ -30,6 +40,38 @@ def test_bloq_examples(bloq_autotester): bloq_autotester(_qroam_clean_multi_data) bloq_autotester(_qroam_clean_multi_dim) +@pytest.mark.parametrize( + "reg, reg_type", + [ + (None, Text), + (Register("selection", QAny(5)), TextBox), + (Register("selection0", QAny(5)), TextBox), + (Register("target0_", QAny(5)), RarrowTextBox), + (Register("junk_target0_", QAny(5)), RarrowTextBox), + (Register("control", QAny(5)), Circle), + ], +) +def test_wire_symbol(reg, reg_type): + bloq = _qroam_clean_multi_dim.make() + assert isinstance(bloq.wire_symbol(reg, ()), reg_type) + assert isinstance(bloq.adjoint().wire_symbol(reg, ()), reg_type) + + +@pytest.mark.parametrize( + "reg, reg_type", + [ + (None, Text), + (Register("selection", QAny(5)), TextBox), + (Register("selection0", QAny(5)), TextBox), + (Register("target0_", QAny(5)), LarrowTextBox), + (Register("control", QAny(5)), Circle), + ], +) +def test_adjoint_wire_symbol(reg, reg_type): + data1 = np.arange(25, dtype=int).reshape((5, 5)) + data2 = (np.arange(25, dtype=int) + 1).reshape((5, 5)) + adjoint_bloq = QROAMCleanAdjoint.build_from_data(data1, data2, log_block_sizes=(1, 1)) + assert isinstance(adjoint_bloq.wire_symbol(reg, ()), reg_type) def test_qroam_clean_qubit_counts(): bloq = _qroam_clean_multi_data.make()