From 98a9e7a04313ac6e0f92b0607583c2a1a4085cf8 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 26 Jan 2022 13:43:11 -0800 Subject: [PATCH 1/8] Allow flattening of subcircuits --- cirq-core/cirq/circuits/circuit_operation.py | 7 ++-- .../classically_controlled_operation_test.py | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index c58687e217f..e702036a168 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -108,6 +108,7 @@ class CircuitOperation(ops.Operation): repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) + flatten_repetitions: bool = False def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -168,6 +169,7 @@ def __eq__(self, other) -> bool: and self.repetitions == other.repetitions and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path + and self.flatten_repetitions == other.flatten_repetitions ) # Methods for getting post-mapping properties of the contained circuit. @@ -190,7 +192,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None: + if self.repetition_ids is not None and not self.flatten_repetitions: circuit_keys = { key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids @@ -251,7 +253,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': if self.param_resolver: circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if not protocols.is_measurement(circuit): + if self.flatten_repetitions or not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( @@ -343,6 +345,7 @@ def __hash__(self): self.param_resolver, self.parent_path, tuple([] if self.repetition_ids is None else self.repetition_ids), + self.flatten_repetitions, ) ), ) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 0daf9f327e7..39a80432103 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -482,6 +482,41 @@ def test_scope_local(): assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) +def test_scope_flat(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True)) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2, flatten_repetitions=True) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['a', 'a', 'a', 'a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +a: ═══@═══^═══@═══^═══@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + def test_scope_extern(): q = cirq.LineQubit(0) inner = cirq.Circuit( From ce6cebecb78a21a615fdc37fb80c62da02c82d52 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:03:15 -0800 Subject: [PATCH 2/8] format --- .../cirq/ops/classically_controlled_operation_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 39a80432103..89a2e2c00f6 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -488,8 +488,12 @@ def test_scope_flat(): cirq.measure(q, key='a'), cirq.X(q).with_classical_controls('a'), ) - middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True)) - outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2, flatten_repetitions=True) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, flatten_repetitions=True + ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) From d75f527cda2f95d94d14c93858761e3ffa85415d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:24:58 -0800 Subject: [PATCH 3/8] Add serialization logic and tests --- cirq-core/cirq/circuits/circuit_operation.py | 12 ++- .../cirq/circuits/circuit_operation_test.py | 20 +++++ .../classically_controlled_operation_test.py | 86 ++++++++++++++++++- .../json_test_data/CircuitOperation.json | 3 +- .../json_test_data/CircuitOperation.repr | 3 +- 5 files changed, 116 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index e702036a168..97dcf118b55 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -297,6 +297,8 @@ def __repr__(self): if self.repetition_ids != self._default_repetition_ids(): # Default repetition_ids need not be specified. args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' + if self.flatten_repetitions: + args += 'flatten_repetitions=True,\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -327,6 +329,8 @@ def dict_str(d: Dict) -> str: elif self.repetitions != 1: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') + if self.flatten_repetitions: + args.append('flat') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' @@ -352,7 +356,7 @@ def __hash__(self): return self._hash def _json_dict_(self): - return { + resp = { 'circuit': self.circuit, 'repetitions': self.repetitions, # JSON requires mappings to have keys of basic types. @@ -363,6 +367,9 @@ def _json_dict_(self): 'repetition_ids': self.repetition_ids, 'parent_path': self.parent_path, } + if self.flatten_repetitions: + resp['flatten_repetitions'] = True + return resp @classmethod def _from_json_dict_( @@ -374,10 +381,11 @@ def _from_json_dict_( param_resolver, repetition_ids, parent_path=(), + flatten_repetitions=False, **kwargs, ): return ( - cls(circuit) + cls(circuit, flatten_repetitions=flatten_repetitions) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 7fbf782ef9b..acf67c76175 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -457,6 +457,26 @@ def test_string_format(): ]), )""" ) + op6 = cirq.CircuitOperation(fc5, flatten_repetitions=True) + assert ( + repr(op6) + == """\ +cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(0)), + cirq.CircuitOperation( + circuit=cirq.FrozenCircuit([ + cirq.Moment( + cirq.X(cirq.LineQubit(1)), + ), + ]), + ), + ), + ]), + flatten_repetitions=True, +)""" + ) def test_json_dict(): diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 89a2e2c00f6..81c6e781fe4 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -482,7 +482,7 @@ def test_scope_local(): assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) -def test_scope_flat(): +def test_scope_flatten_both(): q = cirq.LineQubit(0) inner = cirq.Circuit( cirq.measure(q, key='a'), @@ -504,9 +504,9 @@ def test_scope_flat(): cirq.testing.assert_has_diagram( cirq.Circuit(outer_subcircuit), """ - [ [ 0: ───M───X─── ] ] -0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────── - [ [ a: ═══@═══^═══ ](loops=2) ](loops=2) + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]────────────────── ]────────────────── + [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2, flat) """, use_unicode_characters=True, ) @@ -521,6 +521,84 @@ def test_scope_flat(): ) +def test_scope_flatten_inner(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '0:a', '1:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]────────────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══@═══^═══╬═══╬═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════════════@═══^═══@═══^═══ +""", + use_unicode_characters=True, + ) + + +def test_scope_flatten_outer(): + q = cirq.LineQubit(0) + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls('a'), + ) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, flatten_repetitions=True + ) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_control_keys = [ + str(condition) for op in circuit.all_operations() for condition in cirq.control_keys(op) + ] + assert internal_control_keys == ['0:a', '1:a', '0:a', '1:a'] + assert not cirq.control_keys(outer_subcircuit) + assert not cirq.control_keys(circuit) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────── ]────────────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2, flat) +""", + use_unicode_characters=True, + ) + cirq.testing.assert_has_diagram( + circuit, + """ +0: ─────M───X───M───X───M───X───M───X─── + ║ ║ ║ ║ ║ ║ ║ ║ +0:a: ═══@═══^═══╬═══╬═══@═══^═══╬═══╬═══ + ║ ║ ║ ║ +1:a: ═══════════@═══^═══════════@═══^═══ +""", + use_unicode_characters=True, + ) + + def test_scope_extern(): q = cirq.LineQubit(0) inner = cirq.Circuit( diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 73b77c264db..8db21084719 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -292,7 +292,8 @@ ] }, "parent_path": [], - "repetition_ids": null + "repetition_ids": null, + "flatten_repetitions": true } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 05010e33fa3..a524a9ce0a4 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -34,4 +34,5 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)), ), ]), -param_resolver={sympy.Symbol('theta'): 1.5})] \ No newline at end of file +param_resolver={sympy.Symbol('theta'): 1.5}, +flatten_repetitions=True)] \ No newline at end of file From d7fe99a04df57bffcaf3d951e3e2d4b51c2085c9 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:33:03 -0800 Subject: [PATCH 4/8] Change flatten_repetitions (default False) to use_repetition_ids (default True) --- cirq-core/cirq/circuits/circuit_operation.py | 24 +++++++++---------- .../cirq/circuits/circuit_operation_test.py | 4 ++-- .../classically_controlled_operation_test.py | 8 +++---- .../json_test_data/CircuitOperation.json | 2 +- .../json_test_data/CircuitOperation.repr | 2 +- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 97dcf118b55..a94041f6513 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -108,7 +108,7 @@ class CircuitOperation(ops.Operation): repetition_ids: Optional[List[str]] = dataclasses.field(default=None) parent_path: Tuple[str, ...] = dataclasses.field(default_factory=tuple) extern_keys: FrozenSet['cirq.MeasurementKey'] = dataclasses.field(default_factory=frozenset) - flatten_repetitions: bool = False + use_repetition_ids: bool = True def __post_init__(self): if not isinstance(self.circuit, circuits.FrozenCircuit): @@ -169,7 +169,7 @@ def __eq__(self, other) -> bool: and self.repetitions == other.repetitions and self.repetition_ids == other.repetition_ids and self.parent_path == other.parent_path - and self.flatten_repetitions == other.flatten_repetitions + and self.use_repetition_ids == other.use_repetition_ids ) # Methods for getting post-mapping properties of the contained circuit. @@ -192,7 +192,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) - if self.repetition_ids is not None and not self.flatten_repetitions: + if self.repetition_ids is not None and self.use_repetition_ids: circuit_keys = { key.with_key_path_prefix(repetition_id) for repetition_id in self.repetition_ids @@ -253,7 +253,7 @@ def mapped_circuit(self, deep: bool = False) -> 'cirq.Circuit': if self.param_resolver: circuit = protocols.resolve_parameters(circuit, self.param_resolver, recursive=False) if self.repetition_ids: - if self.flatten_repetitions or not protocols.is_measurement(circuit): + if not self.use_repetition_ids or not protocols.is_measurement(circuit): circuit = circuit * abs(self.repetitions) else: circuit = circuits.Circuit( @@ -297,8 +297,8 @@ def __repr__(self): if self.repetition_ids != self._default_repetition_ids(): # Default repetition_ids need not be specified. args += f'repetition_ids={proper_repr(self.repetition_ids)},\n' - if self.flatten_repetitions: - args += 'flatten_repetitions=True,\n' + if not self.use_repetition_ids: + args += 'use_repetition_ids=False,\n' indented_args = args.replace('\n', '\n ') return f'cirq.CircuitOperation({indented_args[:-4]})' @@ -329,7 +329,7 @@ def dict_str(d: Dict) -> str: elif self.repetitions != 1: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') - if self.flatten_repetitions: + if not self.use_repetition_ids: args.append('flat') if not args: return circuit_msg @@ -349,7 +349,7 @@ def __hash__(self): self.param_resolver, self.parent_path, tuple([] if self.repetition_ids is None else self.repetition_ids), - self.flatten_repetitions, + self.use_repetition_ids, ) ), ) @@ -367,8 +367,8 @@ def _json_dict_(self): 'repetition_ids': self.repetition_ids, 'parent_path': self.parent_path, } - if self.flatten_repetitions: - resp['flatten_repetitions'] = True + if not self.use_repetition_ids: + resp['use_repetition_ids'] = False return resp @classmethod @@ -381,11 +381,11 @@ def _from_json_dict_( param_resolver, repetition_ids, parent_path=(), - flatten_repetitions=False, + use_repetition_ids=True, **kwargs, ): return ( - cls(circuit, flatten_repetitions=flatten_repetitions) + cls(circuit, use_repetition_ids=use_repetition_ids) .with_qubit_mapping(dict(qubit_map)) .with_measurement_key_mapping(measurement_key_map) .with_params(param_resolver) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index acf67c76175..1a77db5ef8e 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -457,7 +457,7 @@ def test_string_format(): ]), )""" ) - op6 = cirq.CircuitOperation(fc5, flatten_repetitions=True) + op6 = cirq.CircuitOperation(fc5, use_repetition_ids=False) assert ( repr(op6) == """\ @@ -474,7 +474,7 @@ def test_string_format(): ), ), ]), - flatten_repetitions=True, + use_repetition_ids=False, )""" ) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 81c6e781fe4..0ec09229892 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -489,10 +489,10 @@ def test_scope_flatten_both(): cirq.X(q).with_classical_controls('a'), ) middle = cirq.Circuit( - cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) ) outer_subcircuit = cirq.CircuitOperation( - middle.freeze(), repetitions=2, flatten_repetitions=True + middle.freeze(), repetitions=2, use_repetition_ids=False ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ @@ -528,7 +528,7 @@ def test_scope_flatten_inner(): cirq.X(q).with_classical_controls('a'), ) middle = cirq.Circuit( - cirq.CircuitOperation(inner.freeze(), repetitions=2, flatten_repetitions=True) + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) ) outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) circuit = outer_subcircuit.mapped_circuit(deep=True) @@ -568,7 +568,7 @@ def test_scope_flatten_outer(): ) middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) outer_subcircuit = cirq.CircuitOperation( - middle.freeze(), repetitions=2, flatten_repetitions=True + middle.freeze(), repetitions=2, use_repetition_ids=False ) circuit = outer_subcircuit.mapped_circuit(deep=True) internal_control_keys = [ diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json index 8db21084719..fe87b08ad67 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.json @@ -293,7 +293,7 @@ }, "parent_path": [], "repetition_ids": null, - "flatten_repetitions": true + "use_repetition_ids": false } ] ] diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index a524a9ce0a4..268baa3f157 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -35,4 +35,4 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ ), ]), param_resolver={sympy.Symbol('theta'): 1.5}, -flatten_repetitions=True)] \ No newline at end of file +use_repetition_ids=False)] \ No newline at end of file From 6f5b6f084e759d6c51697b7df4f6009ed6ab8971 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 15:50:51 -0800 Subject: [PATCH 5/8] Add shape tests for simulation results from flattened subcircuits --- .../cirq/circuits/circuit_operation_test.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index 1a77db5ef8e..e0f4d2cfe8c 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -20,6 +20,13 @@ from cirq.circuits.circuit_operation import _full_join_string_lists +ALL_SIMULATORS = ( + cirq.Simulator(), + cirq.DensityMatrixSimulator(), + cirq.CliffordSimulator(), +) + + def test_properties(): a, b, c = cirq.LineQubit.range(3) circuit = cirq.FrozenCircuit( @@ -878,4 +885,47 @@ def test_mapped_circuit_allows_repeated_keys(): ) +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_both_levels(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['a'].shape == (1, 4, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_outer(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) + outer_subcircuit = cirq.CircuitOperation( + middle.freeze(), repetitions=2, use_repetition_ids=False + ) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + +@pytest.mark.parametrize('sim', ALL_SIMULATORS) +def test_simulate_flattened_subcircuit_inner(sim): + q = cirq.LineQubit(0) + inner = cirq.Circuit(cirq.measure(q, key='a')) + middle = cirq.Circuit( + cirq.CircuitOperation(inner.freeze(), repetitions=2, use_repetition_ids=False) + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = cirq.Circuit(outer_subcircuit) + result = sim.run(circuit) + assert result.records['0:a'].shape == (1, 2, 1) + assert result.records['1:a'].shape == (1, 2, 1) + + # TODO: Operation has a "gate" property. What is this for a CircuitOperation? From 9e5c248518d3935efdecb08f089c637ef9ecc847 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 15 Feb 2022 16:15:24 -0800 Subject: [PATCH 6/8] docs --- cirq-core/cirq/circuits/circuit_operation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index a94041f6513..5df079aaa5d 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -90,6 +90,10 @@ class CircuitOperation(ops.Operation): targets for unbound `ClassicallyControlledOperation` keys. This field is not intended to be set or changed manually, and should be empty in circuits that aren't in the middle of decomposition. + use_repetition_ids: When True, any measurement key in the subcircuit + will have its path prepended with the repetition id for each + repetition. When False, this will not happen and the measurement + key will be repeated. """ _hash: Optional[int] = dataclasses.field(default=None, init=False) From d3ce8f2fcbc5675f0968bcf4a217ff6e17c8ef81 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 10:49:10 -0800 Subject: [PATCH 7/8] address PR comments --- cirq-core/cirq/circuits/circuit_operation.py | 2 +- .../cirq/circuits/circuit_operation_test.py | 6 +++--- .../ops/classically_controlled_operation_test.py | 16 ++++++++-------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 5df079aaa5d..534e45e88ea 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -334,7 +334,7 @@ def dict_str(d: Dict) -> str: # Only add loops if we haven't added repetition_ids. args.append(f'loops={self.repetitions}') if not self.use_repetition_ids: - args.append('flat') + args.append('no_rep_ids') if not args: return circuit_msg return f'{circuit_msg}({", ".join(args)})' diff --git a/cirq-core/cirq/circuits/circuit_operation_test.py b/cirq-core/cirq/circuits/circuit_operation_test.py index e0f4d2cfe8c..087d02f4316 100644 --- a/cirq-core/cirq/circuits/circuit_operation_test.py +++ b/cirq-core/cirq/circuits/circuit_operation_test.py @@ -886,7 +886,7 @@ def test_mapped_circuit_allows_repeated_keys(): @pytest.mark.parametrize('sim', ALL_SIMULATORS) -def test_simulate_flattened_subcircuit_both_levels(sim): +def test_simulate_no_repetition_ids_both_levels(sim): q = cirq.LineQubit(0) inner = cirq.Circuit(cirq.measure(q, key='a')) middle = cirq.Circuit( @@ -901,7 +901,7 @@ def test_simulate_flattened_subcircuit_both_levels(sim): @pytest.mark.parametrize('sim', ALL_SIMULATORS) -def test_simulate_flattened_subcircuit_outer(sim): +def test_simulate_no_repetition_ids_outer(sim): q = cirq.LineQubit(0) inner = cirq.Circuit(cirq.measure(q, key='a')) middle = cirq.Circuit(cirq.CircuitOperation(inner.freeze(), repetitions=2)) @@ -915,7 +915,7 @@ def test_simulate_flattened_subcircuit_outer(sim): @pytest.mark.parametrize('sim', ALL_SIMULATORS) -def test_simulate_flattened_subcircuit_inner(sim): +def test_simulate_no_repetition_ids_inner(sim): q = cirq.LineQubit(0) inner = cirq.Circuit(cirq.measure(q, key='a')) middle = cirq.Circuit( diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index 0ec09229892..fc09d939eee 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -504,9 +504,9 @@ def test_scope_flatten_both(): cirq.testing.assert_has_diagram( cirq.Circuit(outer_subcircuit), """ - [ [ 0: ───M───X─── ] ] -0: ───[ 0: ───[ ║ ║ ]────────────────── ]────────────────── - [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2, flat) + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────────────────── ]──────────────────────── + [ [ a: ═══@═══^═══ ](loops=2, no_rep_ids) ](loops=2, no_rep_ids) """, use_unicode_characters=True, ) @@ -541,9 +541,9 @@ def test_scope_flatten_inner(): cirq.testing.assert_has_diagram( cirq.Circuit(outer_subcircuit), """ - [ [ 0: ───M───X─── ] ] -0: ───[ 0: ───[ ║ ║ ]────────────────── ]──────────── - [ [ a: ═══@═══^═══ ](loops=2, flat) ](loops=2) + [ [ 0: ───M───X─── ] ] +0: ───[ 0: ───[ ║ ║ ]──────────────────────── ]──────────── + [ [ a: ═══@═══^═══ ](loops=2, no_rep_ids) ](loops=2) """, use_unicode_characters=True, ) @@ -581,8 +581,8 @@ def test_scope_flatten_outer(): cirq.Circuit(outer_subcircuit), """ [ [ 0: ───M───X─── ] ] -0: ───[ 0: ───[ ║ ║ ]──────────── ]────────────────── - [ [ a: ═══@═══^═══ ](loops=2) ](loops=2, flat) +0: ───[ 0: ───[ ║ ║ ]──────────── ]──────────────────────── + [ [ a: ═══@═══^═══ ](loops=2) ](loops=2, no_rep_ids) """, use_unicode_characters=True, ) From 564e433656df29c17352c30679e8da15c46dc6b2 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 23 Feb 2022 10:53:24 -0800 Subject: [PATCH 8/8] Add json test for use_repetition_ids --- cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr index 268baa3f157..791ebda07e8 100644 --- a/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/CircuitOperation.repr @@ -28,7 +28,8 @@ cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ repetitions=-2, parent_path=('outer', 'inner'), repetition_ids=['a', 'b'], -qubit_map={cirq.LineQubit(0): cirq.LineQubit(1)}), +qubit_map={cirq.LineQubit(0): cirq.LineQubit(1)}, +use_repetition_ids=True), cirq.CircuitOperation(circuit=cirq.FrozenCircuit([ cirq.Moment( (cirq.X**sympy.Symbol('theta')).on(cirq.LineQubit(0)),