Skip to content

Commit

Permalink
Do not generate default repetition ids if use_repetition_ids=False (q…
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo authored May 31, 2022
1 parent ad0f1ae commit 14b887c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
13 changes: 7 additions & 6 deletions cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ def default_repetition_ids(repetitions: IntParam) -> Optional[List[str]]:
return None


def _full_join_string_lists(list1: Optional[List[str]], list2: Optional[List[str]]):
def _full_join_string_lists(
list1: Optional[List[str]], list2: Optional[List[str]]
) -> Optional[List[str]]:
if list1 is None and list2 is None:
return None # coverage: ignore
if list1 is None:
return list2 # coverage: ignore
if list2 is None:
return list1
return [
f'{REPETITION_ID_SEPARATOR.join([first, second])}' for first in list1 for second in list2
]
return [f'{first}{REPETITION_ID_SEPARATOR}{second}' for first in list1 for second in list2]


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -224,7 +224,7 @@ def qubits(self) -> Tuple['cirq.Qid', ...]:
return tuple(self.qubit_map.get(q, q) for q in ordered_qubits)

def _default_repetition_ids(self) -> Optional[List[str]]:
return default_repetition_ids(self.repetitions)
return default_repetition_ids(self.repetitions) if self.use_repetition_ids else None

def _qid_shape_(self) -> Tuple[int, ...]:
return tuple(q.dimension for q in self.qubits)
Expand Down Expand Up @@ -524,7 +524,8 @@ def repeat(
expected_repetition_id_length = abs(repetitions)

if repetition_ids is None:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
if self.use_repetition_ids:
repetition_ids = default_repetition_ids(expected_repetition_id_length)
elif len(repetition_ids) != expected_repetition_id_length:
raise ValueError(
f'Expected repetition_ids={repetition_ids} length to be '
Expand Down
23 changes: 23 additions & 0 deletions cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock as mock
from typing import Optional

import numpy as np
import pytest
import sympy

import cirq
import cirq.circuits.circuit_operation as circuit_operation
from cirq.circuits.circuit_operation import _full_join_string_lists

ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
Expand Down Expand Up @@ -346,6 +348,27 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps):
assert np.allclose(result.state_vector(), [1, 0])


def test_no_repetition_ids():
def default_repetition_ids(self):
assert False, "Should not call default_repetition_ids"

with mock.patch.object(circuit_operation, 'default_repetition_ids', new=default_repetition_ids):
q = cirq.LineQubit(0)
op = cirq.CircuitOperation(
cirq.Circuit(cirq.X(q), cirq.measure(q)).freeze(),
repetitions=1_000_000,
use_repetition_ids=False,
)
assert op.repetitions == 1_000_000
assert op.repetition_ids is None
_ = repr(op)
_ = str(op)

op2 = op.repeat(10)
assert op2.repetitions == 10_000_000
assert op2.repetition_ids is None


def test_parameterized_repeat():
q = cirq.LineQubit(0)
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q))) ** sympy.Symbol('a')
Expand Down

0 comments on commit 14b887c

Please sign in to comment.