Skip to content

Commit

Permalink
Do not generate default repetition ids if use_repetition_ids=False
Browse files Browse the repository at this point in the history
Fixes #5418
  • Loading branch information
maffoo committed May 31, 2022
1 parent 6273826 commit 1044718
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
21 changes: 11 additions & 10 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ def __post_init__(self):
raise ValueError('repetitions are negative but the circuit is not invertible')

# Initialize repetition_ids to default, if unspecified. Else, validate their length.
loop_size = abs(self.repetitions)
if not self.repetition_ids:
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
elif len(self.repetition_ids) != loop_size:
raise ValueError(
f'Expected repetition_ids to be a list of length {loop_size}, '
f'got: {self.repetition_ids}'
)
if self.use_repetition_ids:
loop_size = abs(self.repetitions)
if not self.repetition_ids:
object.__setattr__(self, 'repetition_ids', self._default_repetition_ids())
elif len(self.repetition_ids) != loop_size:
raise ValueError(
f'Expected repetition_ids to be a list of length {loop_size}, '
f'got: {self.repetition_ids}'
)
elif isinstance(self.repetitions, sympy.Expr):
if self.repetition_ids is not None:
raise ValueError('Cannot use repetition ids with parameterized repetitions')
Expand Down Expand Up @@ -377,7 +378,7 @@ def __repr__(self):
args += f'param_resolver={proper_repr(self.param_resolver)},\n'
if self.parent_path:
args += f'parent_path={proper_repr(self.parent_path)},\n'
if self.repetition_ids != self._default_repetition_ids():
if self.use_repetition_ids and (self.repetition_ids != self._default_repetition_ids()):
# Default repetition_ids need not be specified.
args += f'repetition_ids={proper_repr(self.repetition_ids)},\n'
if not self.use_repetition_ids:
Expand Down Expand Up @@ -408,7 +409,7 @@ def dict_str(d: Dict) -> str:
args.append(f'params={self.param_resolver.param_dict}')
if self.parent_path:
args.append(f'parent_path={self.parent_path}')
if self.repetition_ids != self._default_repetition_ids():
if self.use_repetition_ids and (self.repetition_ids != self._default_repetition_ids()):
# Default repetition_ids need not be specified.
args.append(f'repetition_ids={self.repetition_ids}')
elif self.repetitions != 1:
Expand Down
17 changes: 17 additions & 0 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock as mock
from typing import Optional

import numpy as np
import pytest
import sympy

import cirq
import cirq.circuits.circuit_operation as circuit_operation
from cirq.circuits.circuit_operation import _full_join_string_lists

ALL_SIMULATORS = (cirq.Simulator(), cirq.DensityMatrixSimulator(), cirq.CliffordSimulator())
Expand Down Expand Up @@ -346,6 +348,21 @@ def test_repeat_zero_times(add_measurements, use_repetition_ids, initial_reps):
assert np.allclose(result.state_vector(), [1, 0])


def test_no_repetition_ids():
def default_repetition_ids(self):
assert False, "Should not call default_repetition_ids"

with mock.patch.object(circuit_operation, 'default_repetition_ids', new=default_repetition_ids):
q = cirq.LineQubit(0)
op = cirq.CircuitOperation(
cirq.Circuit(cirq.X(q), cirq.measure(q)).freeze(),
repetitions=1_000_000,
use_repetition_ids=False,
)
_ = repr(op)
_ = str(op)


def test_parameterized_repeat():
q = cirq.LineQubit(0)
op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q))) ** sympy.Symbol('a')
Expand Down

0 comments on commit 1044718

Please sign in to comment.