From 98a9e7a04313ac6e0f92b0607583c2a1a4085cf8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 26 Jan 2022 13:43:11 -0800 Subject: [PATCH 01/18] Allow flattening of subcircuits --- cirq-core/cirq/circuits/circuit_operation.py | 7 ++-- .../classically_controlled_operation_test.py | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index c58687e217f..e702036a168 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -108,6 +108,7 @@ class CircuitOperation(ops.Operation): repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) + flatten_repetitions: bool = False def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -168,6 +169,7 @@ def __eq__(self, other) -> bool: and self.repetitions == other.repetitions and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path + and self.flatten_repetitions == other.flatten_repetitions ) # Methods for getting post-mapping properties of the contained circuit. @@ -190,7 +192,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None: + if self.repetition_ids is not None and not self.flatten_repetitions: circuit_keys = { key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids @@ -251,7 +253,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': if self.param_resolver: circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if not protocols.is_measurement(circuit): + if self.flatten_repetitions or not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( @@ -343,6 +345,7 @@ def __hash__(self): self.param_resolver, self.parent_path, tuple([] if self.repetition_ids is None else self.repetition_ids), + self.flatten_repetitions, ) ), ) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 0daf9f327e7..39a80432103 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -482,6 +482,41 @@ def test_scope_local(): assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) +def test_scope_flat(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True)) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2, flatten_repetitions=True) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['a', 'a', 'a', 'a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +a: ═══@═══^═══@═══^═══@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + def test_scope_extern(): q = cirq.LineQubit(0) inner = cirq.Circuit( From ce6cebecb78a21a615fdc37fb80c62da02c82d52 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:03:15 -0800 Subject: [PATCH 02/18] format --- .../cirq/ops/classically_controlled_operation_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 39a80432103..89a2e2c00f6 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -488,8 +488,12 @@ def test_scope_flat(): cirq.measure(q, key='a'), cirq.X(q).with_classical_controls('a'), ) - middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True)) - outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2, flatten_repetitions=True) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, flatten_repetitions=True + ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) From d75f527cda2f95d94d14c93858761e3ffa85415d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:24:58 -0800 Subject: [PATCH 03/18] Add serialization logic and tests --- cirq-core/cirq/circuits/circuit_operation.py | 12 ++- .../cirq/circuits/circuit_operation_test.py | 20 +++++ .../classically_controlled_operation_test.py | 86 ++++++++++++++++++- .../json_test_data/CircuitOperation.json | 3 +- .../json_test_data/CircuitOperation.repr | 3 +- 5 files changed, 116 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index e702036a168..97dcf118b55 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -297,6 +297,8 @@ def __repr__(self): if self.repetition_ids != self._default_repetition_ids(): # Default repetition_ids need not be specified. args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' + if self.flatten_repetitions: + args += 'flatten_repetitions=True,\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -327,6 +329,8 @@ def dict_str(d: Dict) -> str: elif self.repetitions != 1: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') + if self.flatten_repetitions: + args.append('flat') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -352,7 +356,7 @@ def __hash__(self): return self._hash def _json_dict_(self): - return { + resp = { 'circuit': self.circuit, 'repetitions': self.repetitions, # JSON requires mappings to have keys of basic types. @@ -363,6 +367,9 @@ def _json_dict_(self): 'repetition_ids': self.repetition_ids, 'parent_path': self.parent_path, } + if self.flatten_repetitions: + resp['flatten_repetitions'] = True + return resp @classmethod def _from_json_dict_( @@ -374,10 +381,11 @@ def _from_json_dict_( param_resolver, repetition_ids, parent_path=(), + flatten_repetitions=False, **kwargs, ): return ( - cls(circuit) + cls(circuit, flatten_repetitions=flatten_repetitions) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 7fbf782ef9b..acf67c76175 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -457,6 +457,26 @@ def test_string_format(): ]), )""" ) + op6 = cirq.CircuitOperation(fc5, flatten_repetitions=True) + assert ( + repr(op6) + == """\ +cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(0)), + cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(1)), + ), + ]), + ), + ), + ]), + flatten_repetitions=True, +)""" + ) def test_json_dict(): diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 89a2e2c00f6..81c6e781fe4 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -482,7 +482,7 @@ def test_scope_local(): assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) -def test_scope_flat(): +def test_scope_flatten_both(): q = cirq.LineQubit(0) inner = cirq.Circuit( cirq.measure(q, key='a'), @@ -504,9 +504,9 @@ def test_scope_flat(): cirq.testing.assert_has_diagram( cirq.Circuit(outer_subcircuit), """ - [ [ 0: ───M───X─── ] ] -0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────── - [ [ a: ═══@═══^═══ ](loops=2) ](loops=2) + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]────────────────── ]────────────────── + [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2, flat) """, use_unicode_characters=True, ) @@ -521,6 +521,84 @@ def test_scope_flat(): ) +def test_scope_flatten_inner(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '0:a', '1:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]────────────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══@═══^═══╬═══╬═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════════════@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_scope_flatten_outer(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, flatten_repetitions=True + ) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '1:a', '0:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]────────────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2, flat) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══╬═══╬═══@═══^═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════@═══^═══════════@═══^═══ +""", + use_unicode_characters=True, + ) + + def test_scope_extern(): q = cirq.LineQubit(0) inner = cirq.Circuit( diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 73b77c264db..8db21084719 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,7 +292,8 @@ ] }, "parent_path": [], - "repetition_ids": null + "repetition_ids": null, + "flatten_repetitions": true } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 05010e33fa3..a524a9ce0a4 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -34,4 +34,5 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -param_resolver={sympy.Symbol('theta'): 1.5})] \ No newline at end of file +param_resolver={sympy.Symbol('theta'): 1.5}, +flatten_repetitions=True)] \ No newline at end of file From d7fe99a04df57bffcaf3d951e3e2d4b51c2085c9 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:33:03 -0800 Subject: [PATCH 04/18] Change flatten_repetitions (default False) to use_repetition_ids (default True) --- cirq-core/cirq/circuits/circuit_operation.py | 24 +++++++++---------- .../cirq/circuits/circuit_operation_test.py | 4 ++-- .../classically_controlled_operation_test.py | 8 +++---- .../json_test_data/CircuitOperation.json | 2 +- .../json_test_data/CircuitOperation.repr | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 97dcf118b55..a94041f6513 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -108,7 +108,7 @@ class CircuitOperation(ops.Operation): repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) - flatten_repetitions: bool = False + use_repetition_ids: bool = True def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -169,7 +169,7 @@ def __eq__(self, other) -> bool: and self.repetitions == other.repetitions and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path - and self.flatten_repetitions == other.flatten_repetitions + and self.use_repetition_ids == other.use_repetition_ids ) # Methods for getting post-mapping properties of the contained circuit. @@ -192,7 +192,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None and not self.flatten_repetitions: + if self.repetition_ids is not None and self.use_repetition_ids: circuit_keys = { key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids @@ -253,7 +253,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': if self.param_resolver: circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if self.flatten_repetitions or not protocols.is_measurement(circuit): + if not self.use_repetition_ids or not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( @@ -297,8 +297,8 @@ def __repr__(self): if self.repetition_ids != self._default_repetition_ids(): # Default repetition_ids need not be specified. args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' - if self.flatten_repetitions: - args += 'flatten_repetitions=True,\n' + if not self.use_repetition_ids: + args += 'use_repetition_ids=False,\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -329,7 +329,7 @@ def dict_str(d: Dict) -> str: elif self.repetitions != 1: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') - if self.flatten_repetitions: + if not self.use_repetition_ids: args.append('flat') if not args: return circuit_msg @@ -349,7 +349,7 @@ def __hash__(self): self.param_resolver, self.parent_path, tuple([] if self.repetition_ids is None else self.repetition_ids), - self.flatten_repetitions, + self.use_repetition_ids, ) ), ) @@ -367,8 +367,8 @@ def _json_dict_(self): 'repetition_ids': self.repetition_ids, 'parent_path': self.parent_path, } - if self.flatten_repetitions: - resp['flatten_repetitions'] = True + if not self.use_repetition_ids: + resp['use_repetition_ids'] = False return resp @classmethod @@ -381,11 +381,11 @@ def _from_json_dict_( param_resolver, repetition_ids, parent_path=(), - flatten_repetitions=False, + use_repetition_ids=True, **kwargs, ): return ( - cls(circuit, flatten_repetitions=flatten_repetitions) + cls(circuit, use_repetition_ids=use_repetition_ids) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index acf67c76175..1a77db5ef8e 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -457,7 +457,7 @@ def test_string_format(): ]), )""" ) - op6 = cirq.CircuitOperation(fc5, flatten_repetitions=True) + op6 = cirq.CircuitOperation(fc5, use_repetition_ids=False) assert ( repr(op6) == """\ @@ -474,7 +474,7 @@ def test_string_format(): ), ), ]), - flatten_repetitions=True, + use_repetition_ids=False, )""" ) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 81c6e781fe4..0ec09229892 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -489,10 +489,10 @@ def test_scope_flatten_both(): cirq.X(q).with_classical_controls('a'), ) middle = cirq.Circuit( - cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) ) outer_subcircuit = cirq.CircuitOperation( - middle.freeze(), repetitions=2, flatten_repetitions=True + middle.freeze(), repetitions=2, use_repetition_ids=False ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ @@ -528,7 +528,7 @@ def test_scope_flatten_inner(): cirq.X(q).with_classical_controls('a'), ) middle = cirq.Circuit( - cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) ) outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) circuit = outer_subcircuit.mapped_circuit(deep=True) @@ -568,7 +568,7 @@ def test_scope_flatten_outer(): ) middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) outer_subcircuit = cirq.CircuitOperation( - middle.freeze(), repetitions=2, flatten_repetitions=True + middle.freeze(), repetitions=2, use_repetition_ids=False ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 8db21084719..fe87b08ad67 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -293,7 +293,7 @@ }, "parent_path": [], "repetition_ids": null, - "flatten_repetitions": true + "use_repetition_ids": false } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index a524a9ce0a4..268baa3f157 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -35,4 +35,4 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ ), ]), param_resolver={sympy.Symbol('theta'): 1.5}, -flatten_repetitions=True)] \ No newline at end of file +use_repetition_ids=False)] \ No newline at end of file From 6f5b6f084e759d6c51697b7df4f6009ed6ab8971 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:50:51 -0800 Subject: [PATCH 05/18] Add shape tests for simulation results from flattened subcircuits --- .../cirq/circuits/circuit_operation_test.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 1a77db5ef8e..e0f4d2cfe8c 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -20,6 +20,13 @@ from cirq.circuits.circuit_operation import _full_join_string_lists +ALL_SIMULATORS = ( + cirq.Simulator(), + cirq.DensityMatrixSimulator(), + cirq.CliffordSimulator(), +) + + def test_properties(): a, b, c = cirq.LineQubit.range(3) circuit = cirq.FrozenCircuit( @@ -878,4 +885,47 @@ def test_mapped_circuit_allows_repeated_keys(): ) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_both_levels(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['a'].shape == (1, 4, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_outer(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_inner(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? From 9e5c248518d3935efdecb08f089c637ef9ecc847 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 16:15:24 -0800 Subject: [PATCH 06/18] docs --- cirq-core/cirq/circuits/circuit_operation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index a94041f6513..5df079aaa5d 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -90,6 +90,10 @@ class CircuitOperation(ops.Operation): targets for unbound `ClassicallyControlledOperation` keys. This field is not intended to be set or changed manually, and should be empty in circuits that aren't in the middle of decomposition. + use_repetition_ids: When True, any measurement key in the subcircuit + will have its path prepended with the repetition id for each + repetition. When False, this will not happen and the measurement + key will be repeated. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) From 968773ab86ea067457ae90e9f0d196729044a7b6 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 11 Feb 2022 13:40:15 -0800 Subject: [PATCH 07/18] add repeat_until --- cirq-core/cirq/circuits/circuit_operation.py | 52 +++++++++++++------ .../cirq/circuits/circuit_operation_test.py | 30 +++++++++++ 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 5df079aaa5d..8cb170f28b8 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -113,6 +113,7 @@ class CircuitOperation(ops.Operation): parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) use_repetition_ids: bool = True + repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None) def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -148,6 +149,10 @@ def __post_init__(self): if q_new.dimension != q.dimension: raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') + if self.repeat_until: + if self.use_repetition_ids or self.repetition_ids or self.repetitions != 1: + raise ValueError('Cannot use repetition ids with repeat_until') + # Ensure that param_resolver is converted to an actual ParamResolver. object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) @@ -235,6 +240,23 @@ def _parameter_names_(self) -> AbstractSet[str]: ) } + def mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': + circuit = self.circuit.unfreeze() + if self.qubit_map: + circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) + if self.repetitions < 0: + circuit = circuit ** -1 + if self.measurement_key_map: + circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) + if self.param_resolver: + circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) + if repetition_id: + circuit = protocols.with_rescoped_keys(circuit, (repetition_id,)) + circuit = protocols.with_rescoped_keys( + circuit, self.parent_path, bindable_keys=self.extern_keys + ) + return circuit + def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """Applies all maps to the contained circuit and returns the result. @@ -247,25 +269,15 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': qubit mapping, parameterization, etc.) applied to it. This behaves like `cirq.decompose(self)`, but preserving moment structure. """ - circuit = self.circuit.unfreeze() - if self.qubit_map: - circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) - if self.repetitions < 0: - circuit = circuit ** -1 - if self.measurement_key_map: - circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) - if self.param_resolver: - circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if not self.use_repetition_ids or not protocols.is_measurement(circuit): - circuit = circuit * abs(self.repetitions) + if not self.use_repetition_ids or not protocols.is_measurement(self.circuit): + circuit = self.mapped_single_loop() * abs(self.repetitions) else: circuit = circuits.Circuit( - protocols.with_rescoped_keys(circuit, (rep,)) for rep in self.repetition_ids + self.mapped_single_loop(rep) for rep in self.repetition_ids ) - circuit = protocols.with_rescoped_keys( - circuit, self.parent_path, bindable_keys=self.extern_keys - ) + else: + circuit = self.mapped_single_loop() if deep: circuit = circuit.map_operations( lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op @@ -280,8 +292,14 @@ def _decompose_(self) -> Iterator['cirq.Operation']: return self.mapped_circuit(deep=False).all_operations() def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - for op in self._decompose_(): - protocols.act_on(op, args) + if self.repeat_until: + circuit = self.mapped_single_loop() + while not self.repeat_until.resolve(args.classical_data): + for op in circuit.all_operations(): + protocols.act_on(op, args) + else: + for op in self._decompose_(): + protocols.act_on(op, args) return True # Methods for string representation of the operation. diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index e0f4d2cfe8c..6f1478332e9 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -928,4 +928,34 @@ def test_simulate_flattened_subcircuit_inner(sim): assert result.records['1:a'].shape == (1, 2, 1) +@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()]) +def test_repeat_until(sim): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.measure(q, key='m'), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q) ** 0.2, + cirq.measure(q, key='m'), + ), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(key), + ), + ) + result = sim.run(c) + assert result.records['m'][0][-1] == (1,) + for i in range(len(result.records['m'][0]) - 1): + assert result.records['m'][0][i] == (0,) + + +def test_repeat_until_error(): + with pytest.raises(ValueError, match='Cannot use repetition ids with repeat_until'): + cirq.CircuitOperation( + cirq.FrozenCircuit(), + use_repetition_ids=True, + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? From 1b4d6ef8748683e72583d3187c5144544725b797 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Feb 2022 11:05:05 -0800 Subject: [PATCH 08/18] repr/json/etc --- cirq-core/cirq/circuits/circuit_operation.py | 59 +++++++++++++------ .../cirq/circuits/circuit_operation_test.py | 59 +++++++++++++++++-- .../json_test_data/CircuitOperation.json | 27 ++++++++- .../json_test_data/CircuitOperation.repr | 18 +++++- 4 files changed, 138 insertions(+), 25 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 8cb170f28b8..cbb742f3e46 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -18,15 +18,16 @@ component operations in order, including any nested CircuitOperations. """ from typing import ( - TYPE_CHECKING, AbstractSet, Callable, + cast, Dict, FrozenSet, Iterator, List, Optional, Tuple, + TYPE_CHECKING, Union, ) @@ -94,6 +95,8 @@ class CircuitOperation(ops.Operation): will have its path prepended with the repetition id for each repetition. When False, this will not happen and the measurement key will be repeated. + do_while: A condition that will be tested prior to each iteration of + the circuit. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) @@ -103,6 +106,9 @@ class CircuitOperation(ops.Operation): _cached_control_keys: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( default=None, init=False ) + _cached_mapped_single_loop: Optional['cirq.Circuit'] = dataclasses.field( + default=None, init=False + ) circuit: 'cirq.FrozenCircuit' repetitions: int = 1 @@ -113,7 +119,7 @@ class CircuitOperation(ops.Operation): parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) use_repetition_ids: bool = True - repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None) + do_while: Optional['cirq.Condition'] = dataclasses.field(default=None) def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -149,9 +155,13 @@ def __post_init__(self): if q_new.dimension != q.dimension: raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') - if self.repeat_until: - if self.use_repetition_ids or self.repetition_ids or self.repetitions != 1: - raise ValueError('Cannot use repetition ids with repeat_until') + if self.do_while: + if self.use_repetition_ids or self.repetitions != 1: + raise ValueError('Cannot use repetitions with do_while') + if protocols.measurement_key_objs(self.mapped_single_loop()).isdisjoint( + self.do_while.keys + ): + raise ValueError('Infinite loop: condition is not modified in subcircuit.') # Ensure that param_resolver is converted to an actual ParamResolver. object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) @@ -179,6 +189,7 @@ def __eq__(self, other) -> bool: and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path and self.use_repetition_ids == other.use_repetition_ids + and self.do_while == other.do_while ) # Methods for getting post-mapping properties of the contained circuit. @@ -228,6 +239,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit) else protocols.control_keys(self.mapped_circuit()) ) + if self.do_while is not None: + keys |= frozenset(self.do_while.keys) object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore @@ -241,15 +254,20 @@ def _parameter_names_(self) -> AbstractSet[str]: } def mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': - circuit = self.circuit.unfreeze() - if self.qubit_map: - circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) - if self.repetitions < 0: - circuit = circuit ** -1 - if self.measurement_key_map: - circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) - if self.param_resolver: - circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) + if self._cached_mapped_single_loop is None: + circuit = self.circuit.unfreeze() + if self.qubit_map: + circuit = circuit.transform_qubits(lambda q: self.qubit_map.get(q, q)) + if self.repetitions < 0: + circuit = circuit ** -1 + if self.measurement_key_map: + circuit = protocols.with_measurement_key_mapping(circuit, self.measurement_key_map) + if self.param_resolver: + circuit = protocols.resolve_parameters( + circuit, self.param_resolver, recursive=False + ) + object.__setattr__(self, '_cached_mapped_single_loop', circuit) + circuit = cast(circuits.Circuit, self._cached_mapped_single_loop) if repetition_id: circuit = protocols.with_rescoped_keys(circuit, (repetition_id,)) circuit = protocols.with_rescoped_keys( @@ -292,9 +310,9 @@ def _decompose_(self) -> Iterator['cirq.Operation']: return self.mapped_circuit(deep=False).all_operations() def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - if self.repeat_until: + if self.do_while: circuit = self.mapped_single_loop() - while not self.repeat_until.resolve(args.classical_data): + while self.do_while.resolve(args.classical_data): for op in circuit.all_operations(): protocols.act_on(op, args) else: @@ -321,6 +339,8 @@ def __repr__(self): args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' if not self.use_repetition_ids: args += 'use_repetition_ids=False,\n' + if self.do_while: + args += f'do_while={self.do_while!r},\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -353,6 +373,8 @@ def dict_str(d: Dict) -> str: args.append(f'loops={self.repetitions}') if not self.use_repetition_ids: args.append('flat') + if self.do_while: + args.append(f'while={self.do_while}') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -391,6 +413,8 @@ def _json_dict_(self): } if not self.use_repetition_ids: resp['use_repetition_ids'] = False + if self.do_while: + resp['do_while'] = self.do_while return resp @classmethod @@ -404,10 +428,11 @@ def _from_json_dict_( repetition_ids, parent_path=(), use_repetition_ids=True, + do_while=None, **kwargs, ): return ( - cls(circuit, use_repetition_ids=use_repetition_ids) + cls(circuit, use_repetition_ids=use_repetition_ids, do_while=do_while) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 6f1478332e9..2a510101713 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -484,6 +484,20 @@ def test_string_format(): use_repetition_ids=False, )""" ) + op7 = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(x, key='a')), use_repetition_ids=False, do_while=cirq.KeyCondition(cirq.MeasurementKey('a'))) + assert ( + repr(op7) + == """\ +cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.measure(cirq.LineQubit(0), key=cirq.MeasurementKey(name='a')), + ), + ]), + use_repetition_ids=False, + do_while=cirq.KeyCondition(cirq.MeasurementKey(name='a')), +)""" + ) def test_json_dict(): @@ -933,6 +947,7 @@ def test_repeat_until(sim): q = cirq.LineQubit(0) key = cirq.MeasurementKey('m') c = cirq.Circuit( + cirq.X(q), cirq.measure(q, key='m'), cirq.CircuitOperation( cirq.FrozenCircuit( @@ -940,21 +955,55 @@ def test_repeat_until(sim): cirq.measure(q, key='m'), ), use_repetition_ids=False, - repeat_until=cirq.KeyCondition(key), + do_while=cirq.KeyCondition(key), ), ) result = sim.run(c) - assert result.records['m'][0][-1] == (1,) + assert result.records['m'][0][-1] == (0,) for i in range(len(result.records['m'][0]) - 1): - assert result.records['m'][0][i] == (0,) + assert result.records['m'][0][i] == (1,) + + +@pytest.mark.parametrize_diagram() +def test_repeat_until(): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.X(q), + cirq.measure(q, key='m'), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q) ** 0.2, + cirq.measure(q, key='m'), + ), + use_repetition_ids=False, + do_while=cirq.KeyCondition(key), + ), + ) + cirq.testing.assert_has_diagram( + c, + """ +0: ───X───M───[ 0: ───X^0.2───M('m')─── ](flat, while=m)─── + ║ ║ +m: ═══════@═══╩════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) def test_repeat_until_error(): - with pytest.raises(ValueError, match='Cannot use repetition ids with repeat_until'): + q = cirq.LineQubit(0) + with pytest.raises(ValueError, match='Cannot use repetitions with do_while'): cirq.CircuitOperation( cirq.FrozenCircuit(), use_repetition_ids=True, - repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), + do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) + with pytest.raises(ValueError, match='Infinite loop'): + cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(q, key='m')), + use_repetition_ids=False, + do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), ) diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index fe87b08ad67..0bd8887645d 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,8 +292,33 @@ ] }, "parent_path": [], + "repetition_ids": null + }, + { + "cirq_type": "CircuitOperation", + "circuit": { + "cirq_type": "_SerializedKey", + "key": 1 + }, + "repetitions": 1, + "qubit_map": [], + "measurement_key_map": {}, + "param_resolver": { + "cirq_type": "ParamResolver", + "param_dict": [] + }, + "parent_path": [], "repetition_ids": null, - "use_repetition_ids": false + "use_repetition_ids": false, + "do_while": { + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "0,1,2,3,4", + "path": [] + }, + "index": -1 + } } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 268baa3f157..30640e32058 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -34,5 +34,19 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -param_resolver={sympy.Symbol('theta'): 1.5}, -use_repetition_ids=False)] \ No newline at end of file +param_resolver={sympy.Symbol('theta'): 1.5}), +cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.H(cirq.LineQubit(0)), + cirq.H(cirq.LineQubit(1)), + cirq.H(cirq.LineQubit(2)), + cirq.H(cirq.LineQubit(3)), + cirq.H(cirq.LineQubit(4)), + ), + cirq.Moment( + cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)), + ), +]), +use_repetition_ids=False, +do_while=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')), +)] \ No newline at end of file From 999d9d7170dd9b9cdd2b46511272db244d2870ec Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 12:28:30 -0800 Subject: [PATCH 09/18] format --- cirq-core/cirq/circuits/circuit_operation_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 2e1587dcb99..fa2a179ca1c 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -484,7 +484,11 @@ def test_string_format(): use_repetition_ids=False, )""" ) - op7 = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(x, key='a')), use_repetition_ids=False, do_while=cirq.KeyCondition(cirq.MeasurementKey('a'))) + op7 = cirq.CircuitOperation( + cirq.FrozenCircuit(cirq.measure(x, key='a')), + use_repetition_ids=False, + do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), + ) assert ( repr(op7) == """\ From 76fc8d829b9d054084292ce5ed28db3eded226f4 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 12:38:20 -0800 Subject: [PATCH 10/18] chagne do_while to repeat_until --- cirq-core/cirq/circuits/circuit_operation.py | 38 ++++++++++--------- .../cirq/circuits/circuit_operation_test.py | 28 ++++++-------- .../json_test_data/CircuitOperation.json | 7 +--- .../json_test_data/CircuitOperation.repr | 9 +---- 4 files changed, 35 insertions(+), 47 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 2a29003d81a..067d9947ccc 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -95,7 +95,7 @@ class CircuitOperation(ops.Operation): will have its path prepended with the repetition id for each repetition. When False, this will not happen and the measurement key will be repeated. - do_while: A condition that will be tested prior to each iteration of + repeat_until: A condition that will be tested after each iteration of the circuit. """ @@ -119,7 +119,7 @@ class CircuitOperation(ops.Operation): parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) use_repetition_ids: bool = True - do_while: Optional['cirq.Condition'] = dataclasses.field(default=None) + repeat_until: Optional['cirq.Condition'] = dataclasses.field(default=None) def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -155,11 +155,11 @@ def __post_init__(self): if q_new.dimension != q.dimension: raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') - if self.do_while: + if self.repeat_until: if self.use_repetition_ids or self.repetitions != 1: - raise ValueError('Cannot use repetitions with do_while') + raise ValueError('Cannot use repetitions with repeat_until') if protocols.measurement_key_objs(self.mapped_single_loop()).isdisjoint( - self.do_while.keys + self.repeat_until.keys ): raise ValueError('Infinite loop: condition is not modified in subcircuit.') @@ -189,7 +189,7 @@ def __eq__(self, other) -> bool: and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path and self.use_repetition_ids == other.use_repetition_ids - and self.do_while == other.do_while + and self.repeat_until == other.repeat_until ) # Methods for getting post-mapping properties of the contained circuit. @@ -239,8 +239,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit) else protocols.control_keys(self.mapped_circuit()) ) - if self.do_while is not None: - keys |= frozenset(self.do_while.keys) + if self.repeat_until is not None: + keys |= frozenset(self.repeat_until.keys) object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore @@ -310,11 +310,13 @@ def _decompose_(self) -> Iterator['cirq.Operation']: return self.mapped_circuit(deep=False).all_operations() def _act_on_(self, args: 'cirq.OperationTarget') -> bool: - if self.do_while: + if self.repeat_until: circuit = self.mapped_single_loop() - while self.do_while.resolve(args.classical_data): + while True: for op in circuit.all_operations(): protocols.act_on(op, args) + if self.repeat_until.resolve(args.classical_data): + break else: for op in self._decompose_(): protocols.act_on(op, args) @@ -339,8 +341,8 @@ def __repr__(self): args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' if not self.use_repetition_ids: args += 'use_repetition_ids=False,\n' - if self.do_while: - args += f'do_while={self.do_while!r},\n' + if self.repeat_until: + args += f'repeat_until={self.repeat_until!r},\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -373,8 +375,8 @@ def dict_str(d: Dict) -> str: args.append(f'loops={self.repetitions}') if not self.use_repetition_ids: args.append('no_rep_ids') - if self.do_while: - args.append(f'while={self.do_while}') + if self.repeat_until: + args.append(f'until={self.repeat_until}') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -413,8 +415,8 @@ def _json_dict_(self): } if not self.use_repetition_ids: resp['use_repetition_ids'] = False - if self.do_while: - resp['do_while'] = self.do_while + if self.repeat_until: + resp['repeat_until'] = self.repeat_until return resp @classmethod @@ -428,11 +430,11 @@ def _from_json_dict_( repetition_ids, parent_path=(), use_repetition_ids=True, - do_while=None, + repeat_until=None, **kwargs, ): return ( - cls(circuit, use_repetition_ids=use_repetition_ids, do_while=do_while) + cls(circuit, use_repetition_ids=use_repetition_ids, repeat_until=repeat_until) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index fa2a179ca1c..7f81b574769 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -487,7 +487,7 @@ def test_string_format(): op7 = cirq.CircuitOperation( cirq.FrozenCircuit(cirq.measure(x, key='a')), use_repetition_ids=False, - do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), ) assert ( repr(op7) @@ -499,7 +499,7 @@ def test_string_format(): ), ]), use_repetition_ids=False, - do_while=cirq.KeyCondition(cirq.MeasurementKey(name='a')), + repeat_until=cirq.KeyCondition(cirq.MeasurementKey(name='a')), )""" ) @@ -951,44 +951,40 @@ def test_repeat_until(sim): q = cirq.LineQubit(0) key = cirq.MeasurementKey('m') c = cirq.Circuit( - cirq.X(q), - cirq.measure(q, key='m'), cirq.CircuitOperation( cirq.FrozenCircuit( cirq.X(q) ** 0.2, cirq.measure(q, key='m'), ), use_repetition_ids=False, - do_while=cirq.KeyCondition(key), + repeat_until=cirq.KeyCondition(key), ), ) result = sim.run(c) - assert result.records['m'][0][-1] == (0,) + assert result.records['m'][0][-1] == (1,) for i in range(len(result.records['m'][0]) - 1): - assert result.records['m'][0][i] == (1,) + assert result.records['m'][0][i] == (0,) def test_repeat_until_diagram(): q = cirq.LineQubit(0) key = cirq.MeasurementKey('m') c = cirq.Circuit( - cirq.X(q), - cirq.measure(q, key='m'), cirq.CircuitOperation( cirq.FrozenCircuit( cirq.X(q) ** 0.2, cirq.measure(q, key='m'), ), use_repetition_ids=False, - do_while=cirq.KeyCondition(key), + repeat_until=cirq.KeyCondition(key), ), ) cirq.testing.assert_has_diagram( c, """ -0: ───X───M───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, while=m)─── - ║ ║ -m: ═══════@═══╩══════════════════════════════════════════════════ +0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)─── + ║ +m: ═══╩══════════════════════════════════════════════════ """, use_unicode_characters=True, ) @@ -996,17 +992,17 @@ def test_repeat_until_diagram(): def test_repeat_until_error(): q = cirq.LineQubit(0) - with pytest.raises(ValueError, match='Cannot use repetitions with do_while'): + with pytest.raises(ValueError, match='Cannot use repetitions with repeat_until'): cirq.CircuitOperation( cirq.FrozenCircuit(), use_repetition_ids=True, - do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), ) with pytest.raises(ValueError, match='Infinite loop'): cirq.CircuitOperation( cirq.FrozenCircuit(cirq.measure(q, key='m')), use_repetition_ids=False, - do_while=cirq.KeyCondition(cirq.MeasurementKey('a')), + repeat_until=cirq.KeyCondition(cirq.MeasurementKey('a')), ) diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index a1ecbab0fc9..4a2a02c18e0 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,7 +292,6 @@ ] }, "parent_path": [], -<<<<<<< HEAD "repetition_ids": null }, { @@ -311,7 +310,7 @@ "parent_path": [], "repetition_ids": null, "use_repetition_ids": false, - "do_while": { + "repeat_until": { "cirq_type": "KeyCondition", "key": { "cirq_type": "MeasurementKey", @@ -320,10 +319,6 @@ }, "index": -1 } -======= - "repetition_ids": null, - "use_repetition_ids": false ->>>>>>> master } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index be058da6893..ceb4eee58a6 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -35,7 +35,6 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -<<<<<<< HEAD param_resolver={sympy.Symbol('theta'): 1.5}), cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ cirq.Moment( @@ -50,9 +49,5 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ ), ]), use_repetition_ids=False, -do_while=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')), -)] -======= -param_resolver={sympy.Symbol('theta'): 1.5}, -use_repetition_ids=False)] ->>>>>>> master +repeat_until=cirq.KeyCondition(key=cirq.MeasurementKey('0,1,2,3,4')), +)] \ No newline at end of file From 1fa06dd79f1baefb18fdde4bfe5d3f5865d23cc5 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 12:44:39 -0800 Subject: [PATCH 11/18] merge fix --- cirq-core/cirq/protocols/json_test_data/CircuitOperation.json | 3 ++- cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 4a2a02c18e0..7fbb4421dbe 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,7 +292,8 @@ ] }, "parent_path": [], - "repetition_ids": null + "repetition_ids": null, + "use_repetition_ids": false }, { "cirq_type": "CircuitOperation", diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index ceb4eee58a6..ee527416d26 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -35,7 +35,8 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -param_resolver={sympy.Symbol('theta'): 1.5}), +param_resolver={sympy.Symbol('theta'): 1.5}, +use_repetition_ids=False), cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ cirq.Moment( cirq.H(cirq.LineQubit(0)), From a2cdf4c0506e3288027db5d23c2eb3a28b8d074a Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 12:48:18 -0800 Subject: [PATCH 12/18] make mapped_single_loop private --- cirq-core/cirq/circuits/circuit_operation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 067d9947ccc..6a3cd506b2c 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -158,7 +158,7 @@ def __post_init__(self): if self.repeat_until: if self.use_repetition_ids or self.repetitions != 1: raise ValueError('Cannot use repetitions with repeat_until') - if protocols.measurement_key_objs(self.mapped_single_loop()).isdisjoint( + if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint( self.repeat_until.keys ): raise ValueError('Infinite loop: condition is not modified in subcircuit.') @@ -253,7 +253,7 @@ def _parameter_names_(self) -> AbstractSet[str]: ) } - def mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': + def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circuit': if self._cached_mapped_single_loop is None: circuit = self.circuit.unfreeze() if self.qubit_map: @@ -289,13 +289,13 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """ if self.repetition_ids: if not self.use_repetition_ids or not protocols.is_measurement(self.circuit): - circuit = self.mapped_single_loop() * abs(self.repetitions) + circuit = self._mapped_single_loop() * abs(self.repetitions) else: circuit = circuits.Circuit( - self.mapped_single_loop(rep) for rep in self.repetition_ids + self._mapped_single_loop(rep) for rep in self.repetition_ids ) else: - circuit = self.mapped_single_loop() + circuit = self._mapped_single_loop() if deep: circuit = circuit.map_operations( lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op @@ -311,7 +311,7 @@ def _decompose_(self) -> Iterator['cirq.Operation']: def _act_on_(self, args: 'cirq.OperationTarget') -> bool: if self.repeat_until: - circuit = self.mapped_single_loop() + circuit = self._mapped_single_loop() while True: for op in circuit.all_operations(): protocols.act_on(op, args) From 87a0cb8c4fd7159f9c6923a784f4e39370aab8a1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 25 Feb 2022 11:23:15 -0800 Subject: [PATCH 13/18] Address code review comments. --- cirq-core/cirq/circuits/circuit_operation.py | 8 +++--- .../cirq/circuits/circuit_operation_test.py | 27 ++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 6a3cd506b2c..475d8d03100 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -96,7 +96,11 @@ class CircuitOperation(ops.Operation): repetition. When False, this will not happen and the measurement key will be repeated. repeat_until: A condition that will be tested after each iteration of - the circuit. + the subcircuit. The subcircuit will repeat until condition returns + True, but will always run at least once, and the measurement key + need not be defined prior to the subcircuit (but must be defined in + a measurement within the subcircuit). This field is incompatible + with repetitions or repetition_ids. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) @@ -239,8 +243,6 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit) else protocols.control_keys(self.mapped_circuit()) ) - if self.repeat_until is not None: - keys |= frozenset(self.repeat_until.keys) object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 7f81b574769..120610e090c 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -946,15 +946,36 @@ def test_simulate_no_repetition_ids_inner(sim): assert result.records['1:a'].shape == (1, 2, 1) -@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()]) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) def test_repeat_until(sim): + q = cirq.LineQubit(0) + key = cirq.MeasurementKey('m') + c = cirq.Circuit( + cirq.X(q), + cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q), + cirq.measure(q, key=key), + ), + use_repetition_ids=False, + repeat_until=cirq.KeyCondition(key), + ), + ) + measurements = sim.run(c).records['m'][0] + assert len(measurements) == 2 + assert measurements[0] == (0,) + assert measurements[1] == (1,) + + +@pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()]) +def test_post_selection(sim): q = cirq.LineQubit(0) key = cirq.MeasurementKey('m') c = cirq.Circuit( cirq.CircuitOperation( cirq.FrozenCircuit( cirq.X(q) ** 0.2, - cirq.measure(q, key='m'), + cirq.measure(q, key=key), ), use_repetition_ids=False, repeat_until=cirq.KeyCondition(key), @@ -973,7 +994,7 @@ def test_repeat_until_diagram(): cirq.CircuitOperation( cirq.FrozenCircuit( cirq.X(q) ** 0.2, - cirq.measure(q, key='m'), + cirq.measure(q, key=key), ), use_repetition_ids=False, repeat_until=cirq.KeyCondition(key), From f27868230ca8fc2d0d86245527a7020119d2de1c Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 25 Feb 2022 11:24:37 -0800 Subject: [PATCH 14/18] Fix test --- cirq-core/cirq/circuits/circuit_operation_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 120610e090c..f7534dc74a9 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -1004,8 +1004,6 @@ def test_repeat_until_diagram(): c, """ 0: ───[ 0: ───X^0.2───M('m')─── ](no_rep_ids, until=m)─── - ║ -m: ═══╩══════════════════════════════════════════════════ """, use_unicode_characters=True, ) From 8ca6631f75e32e030a552c8a840b19d9297bcb13 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 27 Feb 2022 18:02:26 -0800 Subject: [PATCH 15/18] simplify branch --- cirq-core/cirq/circuits/circuit_operation.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 475d8d03100..a3ce2a76746 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -289,15 +289,13 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': qubit mapping, parameterization, etc.) applied to it. This behaves like `cirq.decompose(self)`, but preserving moment structure. """ - if self.repetition_ids: - if not self.use_repetition_ids or not protocols.is_measurement(self.circuit): - circuit = self._mapped_single_loop() * abs(self.repetitions) - else: - circuit = circuits.Circuit( - self._mapped_single_loop(rep) for rep in self.repetition_ids - ) - else: - circuit = self._mapped_single_loop() + circuit = ( + circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids) + if self.repetition_ids + and self.use_repetition_ids + and protocols.is_measurement(self.circuit) + else self._mapped_single_loop() * abs(self.repetitions) + ) if deep: circuit = circuit.map_operations( lambda op: op.mapped_circuit(deep=True) if isinstance(op, CircuitOperation) else op From 34cb83f805296836165d039f5c450e0bf75c0df4 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 27 Feb 2022 18:03:59 -0800 Subject: [PATCH 16/18] simplify branch --- cirq-core/cirq/circuits/circuit_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index a3ce2a76746..0203661c466 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -291,7 +291,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """ circuit = ( circuits.Circuit(self._mapped_single_loop(rep) for rep in self.repetition_ids) - if self.repetition_ids + if self.repetition_ids is not None and self.use_repetition_ids and protocols.is_measurement(self.circuit) else self._mapped_single_loop() * abs(self.repetitions) From 351b54275d955ef67bdad232bea308fbeaea78c4 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Sun, 27 Feb 2022 18:07:30 -0800 Subject: [PATCH 17/18] simplify branch --- cirq-core/cirq/circuits/circuit_operation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 0203661c466..89799d55a89 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -272,10 +272,9 @@ def _mapped_single_loop(self, repetition_id: Optional[str] = None) -> 'cirq.Circ circuit = cast(circuits.Circuit, self._cached_mapped_single_loop) if repetition_id: circuit = protocols.with_rescoped_keys(circuit, (repetition_id,)) - circuit = protocols.with_rescoped_keys( + return protocols.with_rescoped_keys( circuit, self.parent_path, bindable_keys=self.extern_keys ) - return circuit def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': """Applies all maps to the contained circuit and returns the result. From a218a697f8f53282ed58b73fc402edf3724048fb Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 28 Feb 2022 16:03:33 -0800 Subject: [PATCH 18/18] add unbound controls in repeat_until to control_keys --- cirq-core/cirq/circuits/circuit_operation.py | 2 ++ .../cirq/circuits/circuit_operation_test.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 7a73694a19f..82efa78d4d8 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -243,6 +243,8 @@ def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit) else protocols.control_keys(self.mapped_circuit()) ) + if self.repeat_until is not None: + keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_() object.__setattr__(self, '_cached_control_keys', keys) return self._cached_control_keys # type: ignore diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 02f7d53b3bb..385e90ec80a 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -986,6 +986,30 @@ def test_repeat_until(sim): assert measurements[1] == (1,) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_repeat_until_sympy(sim): + q1, q2 = cirq.LineQubit.range(2) + circuitop = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.X(q2), + cirq.measure(q2, key='b'), + ), + use_repetition_ids=False, + repeat_until=cirq.SympyCondition(sympy.Eq(sympy.Symbol('a'), sympy.Symbol('b'))), + ) + c = cirq.Circuit( + cirq.measure(q1, key='a'), + circuitop, + ) + # Validate commutation + assert len(c) == 2 + assert cirq.control_keys(circuitop) == {cirq.MeasurementKey('a')} + measurements = sim.run(c).records['b'][0] + assert len(measurements) == 2 + assert measurements[0] == (1,) + assert measurements[1] == (0,) + + @pytest.mark.parametrize('sim', [cirq.Simulator(), cirq.DensityMatrixSimulator()]) def test_post_selection(sim): q = cirq.LineQubit(0)